<a href="https://colab.research.google.com/github/satvik-venkatesh/Wave-U-net-TF2/blob/main/Wave-U-net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Layer


In [None]:
# Scratch check of the custom SimpleDense layer on a rank-5 input.
# NOTE: SimpleDense is defined in a later cell; that cell must be executed before this one.
# Dedicated names keep this demo from overwriting `i`, `X`, `o` and `model`,
# which the Wave-U-Net cell also uses.
dense_demo_input = tf.keras.layers.Input(shape=(1, 128, 128, 1))
dense_demo_hidden = SimpleDense()(dense_demo_input)
dense_demo_output = tf.keras.layers.Dense(1)(dense_demo_hidden)
dense_demo_model = tf.keras.Model(inputs=dense_demo_input, outputs=dense_demo_output)
dense_demo_model.summary()

tf.shape(self.w): [ 1 32]
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         [(None, 1, 128, 128, 1)]  0         
_________________________________________________________________
simple_dense_2 (SimpleDense) (None, 1, 128, 128, 32)   64        
_________________________________________________________________
dense_2 (Dense)              (None, 1, 128, 128, 1)    33        
Total params: 97
Trainable params: 97
Non-trainable params: 0
_________________________________________________________________


In [None]:
# inherit from this base class
from tensorflow.keras.layers import Layer

class SimpleDense(Layer):
    '''Minimal fully-connected layer computing ``inputs @ w + b`` over the last axis.'''

    def __init__(self, units=32, **kwargs):
        '''Initializes the instance attributes.

        :param units: number of output features.
        :param kwargs: standard Keras ``Layer`` arguments (e.g. ``name``).
        '''
        super(SimpleDense, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        '''Creates the kernel, shape (in_features, units), and the bias, shape (units,).'''
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(name="kernel",
            initial_value=w_init(shape=(input_shape[-1], self.units),
                                 dtype='float32'),
            trainable=True)

        b_init = tf.zeros_initializer()
        self.b = tf.Variable(name="bias",
            initial_value=b_init(shape=(self.units,), dtype='float32'),
            trainable=True)

    def call(self, inputs):
        '''Returns ``inputs @ w + b``; leading axes of ``inputs`` are kept (tf.matmul broadcasts them).'''
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        '''Returns the constructor arguments so the layer can be re-created from config.'''
        config = super(SimpleDense, self).get_config()
        config.update({"units": self.units})
        return config

In [None]:
# inherit from this base class
from tensorflow.keras.layers import Layer

class AudioClipLayer(Layer):
    '''Passes audio through unchanged while training and clips it to [-1, 1] otherwise.'''

    def __init__(self, units=32, **kwargs):
        '''Initializes the instance attributes.

        :param units: not used by the computation; accepted and stored for constructor compatibility.
        :param kwargs: standard Keras ``Layer`` arguments (e.g. ``name``).
        '''
        super(AudioClipLayer, self).__init__(**kwargs)
        self.units = units

    def call(self, inputs, training=None):
        '''Returns ``inputs`` when ``training`` is truthy, otherwise ``inputs`` clipped to [-1, 1].

        ``training`` defaults to None (treated as inference), so calling ``call`` directly
        without it is allowed.
        '''
        if training:
            return inputs
        return tf.clip_by_value(inputs, -1.0, 1.0)

    def get_config(self):
        '''Returns the constructor arguments so the layer can be re-created from config.'''
        config = super(AudioClipLayer, self).get_config()
        config.update({"units": self.units})
        return config

In [None]:
# Interpolation layer
from tensorflow.keras.layers import Layer

class InterpolationLayer(Layer):
    '''Learned 2x upsampling along the width axis by interpolating neighbouring positions.

    Expects a 4D input (batch, 1, width, features). Between every pair of adjacent
    positions x[t] and x[t+1] it inserts  sigmoid(w) * x[t] + (1 - sigmoid(w)) * x[t+1],
    where ``w`` is a learned per-feature weight. The output therefore has shape
    (batch, 1, 2 * width - 1, features), with the original samples at even positions.
    '''

    def __init__(self, units=32, **kwargs):
        '''Initializes the instance attributes.

        :param units: not used by the computation; accepted and stored for constructor compatibility.
        :param kwargs: standard Keras ``Layer`` arguments (e.g. ``name``).
        '''
        super(InterpolationLayer, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        '''Creates one interpolation weight per feature channel (last input axis).

        :raises ValueError: if the feature dimension is not statically known.
        '''
        # Read the last axis instead of hard-coding index 3; this also accepts tuple shapes.
        features = tf.TensorShape(input_shape)[-1]
        if features is None:
            raise ValueError("InterpolationLayer needs a statically known feature dimension.")
        self.features = int(features)

        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(name="kernel",
            initial_value=w_init(shape=(self.features, ),
                                 dtype='float32'),
            trainable=True)

    def call(self, inputs):
        '''Interleaves the input samples with the learned interpolations between them.'''
        w_scaled = tf.math.sigmoid(self.w)
        counter_w = 1 - w_scaled

        # Conv filter of shape (1, 2, features, features). Tap 0 weights x[t] by w and
        # tap 1 weights x[t+1] by (1 - w). The matrices are diagonal, so channels never mix.
        conv_weights = tf.expand_dims(
            tf.stack([tf.linalg.diag(w_scaled), tf.linalg.diag(counter_w)], axis=0),
            axis=0)

        # VALID padding with width 2 yields width - 1 interpolated values.
        intermediate_vals = tf.nn.conv2d(inputs, conv_weights, strides=[1, 1, 1, 1], padding="VALID")

        # Move the width axis to the front so the two sequences can be interleaved with gather.
        intermediate_vals = tf.transpose(intermediate_vals, [2, 0, 1, 3])
        originals = tf.transpose(inputs, [2, 0, 1, 3])

        num_entries = originals.shape.as_list()[0]
        if num_entries is None:
            raise ValueError("InterpolationLayer needs a statically known width (axis 2).")
        stacked = tf.concat([originals, intermediate_vals], axis=0)

        # Even output positions take original sample idx // 2. Odd positions take the
        # interpolated value idx // 2, which lives at offset num_entries inside `stacked`.
        indices = [idx // 2 if idx % 2 == 0 else num_entries + idx // 2
                   for idx in range(2 * num_entries - 1)]
        interleaved = tf.gather(stacked, indices)

        return tf.transpose(interleaved, [1, 2, 0, 3])

In [None]:
# inherit from this base class
from tensorflow.keras.layers import Layer

class CropLayer(Layer):
    '''Centre-crops its input along axis 1 (width/time) to the width of a reference tensor.

    Only the static shape of the reference tensor ``x2`` is used. It is read once in
    the constructor. If ``x2`` is None the layer is the identity.
    '''

    def __init__(self, x2, match_feature_dim=True, **kwargs):
        '''Initializes the instance attributes.

        :param x2: reference tensor whose shape defines the crop target, or None.
        :param match_feature_dim: if True, the last (feature) dimension of the input
            must equal the reference's last dimension.
        :param kwargs: standard Keras ``Layer`` arguments (e.g. ``name``).
        '''
        super(CropLayer, self).__init__(**kwargs)
        self.match_feature_dim = match_feature_dim
        # Keep only the shape; holding on to the symbolic reference tensor is unnecessary.
        self.target_shape = None if x2 is None else x2.shape.as_list()

    def call(self, inputs):
        '''Returns ``inputs`` centre-cropped to the reference width (or unchanged without a reference).'''
        if self.target_shape is None:
            return inputs
        return self.crop(inputs, self.target_shape, self.match_feature_dim)

    def crop(self, tensor, target_shape, match_feature_dim=True):
        '''
        Centre-crops ``tensor`` along axis 1 to ``target_shape[1]``.

        If the size difference is odd, the extra entry is removed from the end
        (the last entries are cropped first).

        :param tensor: tensor of rank >= 2, e.g. [batch_size, width, channels].
        :param target_shape: shape (list/tuple) to crop to; only entry 1 and, when
            ``match_feature_dim`` is set, the last entry are inspected.
        :param match_feature_dim: if True and the tensor has rank >= 3, the last
            dimension of ``tensor`` must equal that of ``target_shape``.
        :return: cropped tensor (``tensor`` itself when no cropping is needed).
        :raises ValueError: if ``tensor`` is narrower than the target, or the feature
            dimensions differ while ``match_feature_dim`` is True.
        '''
        shape = tensor.shape
        dims = shape.as_list() if hasattr(shape, "as_list") else list(shape)

        if match_feature_dim and len(dims) >= 3 and len(target_shape) >= 3:
            feat, target_feat = dims[-1], target_shape[-1]
            if feat is not None and target_feat is not None and feat != target_feat:
                raise ValueError("Feature dimension {} does not match target {}".format(feat, target_feat))

        width = dims[1]
        target_width = target_shape[1]
        ddif = width - target_width
        if ddif < 0:
            raise ValueError("Cannot crop width {} to larger target width {}".format(width, target_width))
        if ddif % 2 != 0:
            print("WARNING: Cropping with uneven number of extra entries on one side")
        if ddif == 0:
            return tensor

        crop_start = ddif // 2
        # Slicing only axis 1 keeps every remaining axis, so any rank >= 2 works.
        return tensor[:, crop_start:crop_start + target_width]

def crop(tensor, target_shape, match_feature_dim=True):
    '''
    Centre-crops ``tensor`` along axis 1 (the width/time axis) to ``target_shape[1]``.

    If the size difference is odd, the extra entry is removed from the end
    (the last entries are cropped first).

    :param tensor: tensor or array of rank >= 2, e.g. [batch_size, width, channels].
        Its ``.shape`` may be a TensorShape (``as_list``) or a plain tuple.
    :param target_shape: shape (list/tuple) to crop to; only entry 1 and, when
        ``match_feature_dim`` is set, the last entry are inspected.
    :param match_feature_dim: if True and the tensor has rank >= 3, the last
        (feature) dimension of ``tensor`` must equal that of ``target_shape``.
    :return: cropped tensor (``tensor`` itself when no cropping is needed).
    :raises ValueError: if ``tensor`` is narrower than the target, or the feature
        dimensions differ while ``match_feature_dim`` is True.
    '''
    shape = tensor.shape
    dims = shape.as_list() if hasattr(shape, "as_list") else list(shape)

    if match_feature_dim and len(dims) >= 3 and len(target_shape) >= 3:
        feat, target_feat = dims[-1], target_shape[-1]
        if feat is not None and target_feat is not None and feat != target_feat:
            raise ValueError("Feature dimension {} does not match target {}".format(feat, target_feat))

    width = dims[1]
    target_width = target_shape[1]
    ddif = width - target_width
    # A negative difference used to produce a nonsensical slice; reject it explicitly.
    if ddif < 0:
        raise ValueError("Cannot crop width {} to larger target width {}".format(width, target_width))
    if ddif % 2 != 0:
        print("WARNING: Cropping with uneven number of extra entries on one side")
    if ddif == 0:
        return tensor

    crop_start = ddif // 2
    # Slicing only axis 1 keeps every remaining axis, so any rank >= 2 works.
    return tensor[:, crop_start:crop_start + target_width]

def audio_clip(inputs, training):
    '''Function version of the audio clipping rule.

    :param inputs: tensor of audio samples.
    :param training: when truthy, the input is returned untouched.
    :return: ``inputs`` while training, otherwise ``inputs`` clipped to [-1, 1].
    '''
    if not training:
        return tf.maximum(tf.minimum(inputs, 1.0), -1.0)
    return inputs

# def crop_and_concat(x1,x2, match_feature_dim=True):
#     '''
#     Copy-and-crop operation for two feature maps of different size.
#     Crops the first input x1 equally along its borders so that its shape is equal to 
#     the shape of the second input x2, then concatenates them along the feature channel axis.
#     :param x1: First input that is cropped and combined with the second input
#     :param x2: Second input
#     :return: Combined feature map
#     '''
#     if x2 is None:
#         return x1

#     x1 = crop(x1,x2.get_shape().as_list(), match_feature_dim)
#     return tf.concat([x1, x2], axis=2)

In [None]:
# inherit from this base class
from tensorflow.keras.layers import Layer

class DiffOutputLayer(Layer):
    '''Difference output layer of the Wave-U-Net.

    Called with ``[features, mix]``. Every source except the last is predicted from
    ``features`` by its own Conv1D (``num_channels`` filters of width ``filter_width``,
    'valid' padding), followed by audio clipping. The last source is the mixture
    (centre-cropped to the prediction width) minus the sum of the other predictions,
    also audio-clipped. Returns a dict mapping each name in ``source_names`` to its estimate.
    '''

    def __init__(self, source_names, num_channels, filter_width, **kwargs):
        '''Initializes the instance attributes and the per-source convolutions.

        :param source_names: non-empty sequence of source names; the last one is
            computed as the residual of the mixture.
        :param num_channels: number of audio channels per source estimate.
        :param filter_width: width of each output Conv1D.
        :param kwargs: standard Keras ``Layer`` arguments (e.g. ``name``).
        :raises ValueError: if ``source_names`` is empty.
        '''
        super(DiffOutputLayer, self).__init__(**kwargs)
        if not source_names:
            raise ValueError("source_names must contain at least one source name.")
        self.source_names = source_names
        self.num_channels = num_channels
        self.filter_width = filter_width

        # One convolution per independently estimated source. A single shared Conv1D
        # would make all of these estimates identical.
        self.source_convs = [
            tf.keras.layers.Conv1D(self.num_channels, self.filter_width, padding='valid')
            for _ in self.source_names[:-1]
        ]

    def call(self, inputs, training=None):
        '''Computes the per-source estimates from ``inputs = [features, mix]``.'''
        features, mix = inputs[0], inputs[1]
        outputs = {}
        sum_source = None
        for name, conv in zip(self.source_names[:-1], self.source_convs):
            estimate = self._audio_clip(conv(features), training)
            outputs[name] = estimate
            sum_source = estimate if sum_source is None else sum_source + estimate

        if sum_source is None:
            # Only one source: it is the whole mixture.
            last_source = mix
        else:
            # Compute the last source as the residual of the mixture.
            last_source = self._centre_crop(mix, sum_source) - sum_source
        outputs[self.source_names[-1]] = self._audio_clip(last_source, training)
        return outputs

    @staticmethod
    def _audio_clip(x, training):
        '''Identity while training, otherwise clips to [-1, 1] (same rule as AudioClipLayer).'''
        if training:
            return x
        return tf.clip_by_value(x, -1.0, 1.0)

    @staticmethod
    def _centre_crop(tensor, reference):
        '''Centre-crops ``tensor`` along axis 1 to the static width of ``reference``.'''
        width = tensor.shape[1]
        target_width = reference.shape[1]
        if width is None or target_width is None:
            raise ValueError("DiffOutputLayer needs statically known widths to crop the mixture.")
        ddif = width - target_width
        if ddif < 0:
            raise ValueError("Cannot crop width {} to larger target width {}".format(width, target_width))
        if ddif == 0:
            return tensor
        crop_start = ddif // 2
        return tensor[:, crop_start:crop_start + target_width]

    def get_config(self):
        '''Returns the constructor arguments so the layer can be re-created from config.'''
        config = super(DiffOutputLayer, self).get_config()
        config.update({
            "source_names": list(self.source_names),
            "num_channels": self.num_channels,
            "filter_width": self.filter_width,
        })
        return config


def output_layer(source_names, num_channels, filter_width):
    '''Placeholder for a functional-style output-layer builder; not implemented.

    It ignores its arguments and returns None. The model cell uses DiffOutputLayer instead.
    TODO: implement or remove.
    '''
    return None

In [None]:
"""
Code for Wave U-Net
"""

# Hyper-parameters
num_initial_filters = 24      # down-sampling layer i uses (i + 1) * num_initial_filters filters
num_layers = 12               # number of down-/up-sampling blocks
kernel_size = 15              # Conv1D width in the down-sampling path
merge_filter_size = 5         # Conv1D width in the up-sampling path
source_names = ["bass", "drums", "other", "vocals"]
num_channels = 1              # audio channels per source estimate
output_filter_size = 1        # Conv1D width of the output layer
input_length = 147443         # input samples; with these settings the outputs have 16389 samples

enc_outputs = []  # encoder activations, reused by the skip connections

raw_input = tf.keras.layers.Input(shape=(input_length, 1), name="raw_input")
X = raw_input
inp = raw_input

# Down-sampling path: conv -> LeakyReLU -> store for the skip connection -> decimate by 2.
for i in range(num_layers):
    X = tf.keras.layers.Conv1D(filters=num_initial_filters + (num_initial_filters * i),
                               kernel_size=kernel_size, strides=1,
                               padding='valid', name="Down_Conv_" + str(i))(X)
    X = tf.keras.layers.LeakyReLU(name="DC_Act_" + str(i))(X)
    enc_outputs.append(X)
    X = tf.keras.layers.Lambda(lambda x: x[:, ::2, :], name="Decimate_" + str(i))(X)

# Bottleneck.
X = tf.keras.layers.Conv1D(filters=num_initial_filters + (num_initial_filters * num_layers),
                           kernel_size=kernel_size, strides=1,
                           padding='valid', name="Down_Conv_" + str(num_layers))(X)
X = tf.keras.layers.LeakyReLU(name="DC_Act_" + str(num_layers))(X)

# Up-sampling path.
for i in range(num_layers):
    # InterpolationLayer works on (batch, 1, width, features), so add and then remove a unit axis.
    X = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=1), name="exp_dims_" + str(i))(X)
    # Learned interpolation between neighbouring time positions (width-2 filter); the
    # responses are inserted between the two respective inputs.
    X = InterpolationLayer(name="IntPol_" + str(i))(X)
    X = tf.keras.layers.Lambda(lambda x: tf.squeeze(x, axis=1), name="sq_dims_" + str(i))(X)

    # Skip connection: centre-crop the matching encoder output to X, then concatenate channels.
    c_layer = CropLayer(X, False, name="crop_layer_" + str(i))(enc_outputs[-i - 1])
    X = tf.keras.layers.Concatenate(axis=2, name="concatenate_" + str(i))([X, c_layer])

    X = tf.keras.layers.Conv1D(filters=num_initial_filters + (num_initial_filters * (num_layers - i - 1)),
                               kernel_size=merge_filter_size, strides=1,
                               padding='valid', name="Up_Conv_" + str(i))(X)
    X = tf.keras.layers.LeakyReLU(name="UC_Act_" + str(i))(X)

# Final skip connection with the raw input.
c_layer = CropLayer(X, False, name="crop_layer_" + str(num_layers))(inp)
X = tf.keras.layers.Concatenate(axis=2, name="concatenate_" + str(num_layers))([X, c_layer])
# No AudioClipLayer on this hidden feature map: that layer is the identity while
# training but clips at inference, so clipping here would make inference differ
# from what was trained. Clipping is applied to the source outputs only.

# Difference output: the last source is the cropped mixture minus the other estimates.
cropped_input = CropLayer(X, False, name="crop_layer_" + str(num_layers + 1))(inp)
X = DiffOutputLayer(source_names, num_channels, output_filter_size, name="diff_out")([X, cropped_input])

o = X
model = tf.keras.Model(inputs=raw_input, outputs=o)
model.summary()

last_source: Tensor("diff_out/audio_clip_layer_3/Maximum:0", shape=(None, 16389, 1), dtype=float32)
Reached here!!!
{'bass': <tf.Tensor 'diff_out/audio_clip_layer/Maximum:0' shape=(None, 16389, 1) dtype=float32>, 'drums': <tf.Tensor 'diff_out/audio_clip_layer_1/Maximum:0' shape=(None, 16389, 1) dtype=float32>, 'other': <tf.Tensor 'diff_out/audio_clip_layer_2/Maximum:0' shape=(None, 16389, 1) dtype=float32>, 'vocals': <tf.Tensor 'diff_out/audio_clip_layer_3/Maximum:0' shape=(None, 16389, 1) dtype=float32>}
Model: "model_24"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
raw_input (InputLayer)          [(None, 147443, 1)]  0                                            
__________________________________________________________________________________________________
Down_Conv_0 (Conv1D)            (None, 147429, 24)   384         raw_input[

In [None]:
# Symbolic outputs of the Wave-U-Net model: a dict mapping each source name to its estimate.
o

{'bass': <KerasTensor: shape=(None, 16389, 1) dtype=float32 (created by layer 'diff_out')>,
 'drums': <KerasTensor: shape=(None, 16389, 1) dtype=float32 (created by layer 'diff_out')>,
 'other': <KerasTensor: shape=(None, 16389, 1) dtype=float32 (created by layer 'diff_out')>,
 'vocals': <KerasTensor: shape=(None, 16389, 1) dtype=float32 (created by layer 'diff_out')>}

In [None]:
# Scratch check: a shape tuple containing None becomes an object-dtype numpy array
# (see the output of the next cell). Presumably this explores why shape arithmetic with
# unknown (None) dimensions is problematic in crop - TODO confirm or delete this cell.
a = np.array((5,8, 9, 11, None))

In [None]:
# Display the scratch array built in the previous cell.
a

array([5, 8, 9, 11, None], dtype=object)

In [None]:
# Sanity check of the decimation slice used in the encoder. After the model cell, X is
# the dict of source outputs, so apply the slice to the model's input tensor instead.
tf.keras.layers.Lambda(lambda x: x[:, ::2, :])(raw_input)