# Demucs Model

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import time

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np

%config Completer.use_jedi = False

#### Weight rescaling

In [2]:
class WeightRescaling(keras.layers.Layer):
    """Wrap a ``keras.layers.Conv1D`` and rescale its initial kernel once at build time.

    At build, the wrapped layer's kernel is divided by
    ``sqrt(std(kernel) / reference)``. The rescaled values are written back
    into the wrapped layer's existing kernel variable, so the kernel remains a
    trainable ``tf.Variable`` that the wrapped layer keeps using in ``call``.

    Args:
        layer: the layer to wrap. It must be a ``keras.layers.Conv1D``
            instance (subclasses pass the ``isinstance`` check).
        reference: the reference scale ``a`` in the rescaling formula.
            Defaults to 0.1.

    Raises:
        ValueError: at build time, if ``layer`` is not a ``keras.layers.Conv1D``.
    """

    def __init__(self, layer, reference=0.1, **kwargs):
        super(WeightRescaling, self).__init__(**kwargs)
        self.layer = layer
        self.reference = reference
        self.is_conv1d = isinstance(self.layer, keras.layers.Conv1D)

    def build(self, batch_input_shape):
        # Validate before building anything, so a wrong wrapped layer fails fast.
        if not self.is_conv1d:
            raise ValueError('`WeightRescaling` should wrap `keras.layers.Conv1D` layer')

        batch_input_shape = tf.TensorShape(batch_input_shape)
        if not self.layer.built:
            self.layer.build(batch_input_shape)

        kernel = self.layer.kernel
        alpha = keras.backend.std(kernel) / self.reference
        # Write into the existing variable rather than replacing the attribute
        # with a tensor. A replaced attribute would be a constant computed once
        # here, so gradients would never reach the trainable kernel.
        kernel.assign(kernel / tf.sqrt(alpha))
        super(WeightRescaling, self).build(batch_input_shape)

    def call(self, inputs, training=True):
        """Apply the wrapped layer. ``training`` is accepted but not forwarded."""
        outputs = self.layer(inputs)
        return outputs

#### GLU (gated linear unit)

In [3]:
class GatedLinearUnit(keras.layers.Layer):
    """Gated linear unit with learned projections: ``Dense_a(x) * sigmoid(Dense_b(x))``.

    Both Dense projections are created in ``build`` with
    ``input_channels // 2`` units. The output therefore has half as many
    channels (last axis) as the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.multiply = keras.layers.Multiply()

    def build(self, batch_input_shape):
        half_width = batch_input_shape[-1] // 2
        self.units = half_width
        # Value path (no activation) and gate path (sigmoid) share the same width.
        self.linear = keras.layers.Dense(half_width)
        self.sigmoid = keras.layers.Dense(half_width, activation='sigmoid')

    def call(self, inputs):
        value = self.linear(inputs)
        gate = self.sigmoid(inputs)
        return self.multiply([value, gate])

In [4]:
def num2size(num, base=64, depth=6):
    """Return the channel width of encoder/decoder stage ``num``.

    Widths double per stage: stage 1 has ``base`` channels and stage ``k`` has
    ``base * 2 ** (k - 1)``. With the defaults, the widths are
    64, 128, 256, 512, 1024, 2048.

    Args:
        num: stage index, an integer value in ``1..depth``.
        base: width of stage 1. Defaults to 64.
        depth: number of defined stages. Defaults to 6.

    Returns:
        int: the number of channels for that stage.

    Raises:
        ValueError: if ``num`` is not an integer value in ``1..depth``.
    """
    # Explicit check instead of ``assert``, so validation survives ``python -O``.
    if num not in range(1, depth + 1):
        raise ValueError("{} size is not defined".format(num))
    # int() so callers always get an int channel count (2.0 -> 128, not 128.0).
    return base * 2 ** (int(num) - 1)

In [5]:
class Encoder(keras.Model):
    """One encoder stage: strided Conv1D + ReLU, then a 1x1 Conv1D and a GLU.

    ``number`` picks the stage width ``size`` via ``num2size``. The first conv
    (kernel 8, stride 4, ``padding='same'``) maps to ``size`` channels and
    shortens the time axis to ``ceil(T / 4)``. The 1x1 conv doubles the
    channels to ``2 * size``, and the GLU halves them back to ``size``.
    """

    def __init__(self, number, **kwargs):
        super().__init__(**kwargs)
        self.size = num2size(number)
        width = self.size
        self.conv1d1 = WeightRescaling(
            keras.layers.Conv1D(width, kernel_size=8, strides=4, padding='same',
                                use_bias=False, activation=keras.activations.relu))
        self.conv1d2 = WeightRescaling(
            keras.layers.Conv1D(2 * width, kernel_size=1, strides=1, padding='same',
                                use_bias=False))
        self.glu = GatedLinearUnit()

    def call(self, inputs):
        hidden = self.conv1d2(self.conv1d1(inputs))
        return self.glu(hidden)

In [6]:
class Decoder(keras.Model):
    """One decoder stage: Conv1D + GLU, then a strided Conv1DTranspose + ReLU.

    ``number`` picks the output width ``size`` via ``num2size``. The kernel-3
    conv maps to ``2 * size`` channels and the GLU halves that to ``size``.
    The transposed conv (kernel 8, stride 4, ``padding='same'``) keeps ``size``
    channels and stretches the time axis by 4x.
    """

    def __init__(self, number, **kwargs):
        super().__init__(**kwargs)
        self.size = num2size(number)
        width = self.size
        self.conv1d1 = WeightRescaling(
            keras.layers.Conv1D(2 * width, kernel_size=3, strides=1, padding='same',
                                use_bias=False))
        self.conv1d2 = WeightRescaling(
            keras.layers.Conv1DTranspose(width, kernel_size=8, strides=4, padding='same',
                                         use_bias=False, activation=keras.activations.relu))
        self.glu = GatedLinearUnit()

    def call(self, inputs):
        gated = self.glu(self.conv1d1(inputs))
        return self.conv1d2(gated)

In [7]:
# Encoders e1..e6 use stage widths num2size(1)..num2size(6) (64 -> 2048 channels).
e1, e2, e3, e4, e5, e6 = [Encoder(stage) for stage in range(1, 7)]

# dN takes the width of encoder stage N and outputs the width of stage N-1,
# so it is built with number N-1. Built in the order d6, d5, ..., d2.
d6, d5, d4, d3, d2 = [Decoder(stage) for stage in range(5, 0, -1)]

In [8]:
num_sources = 4 # vocal, bass, drum, others
audio_channels = 2  # stereo: output channels per separated source

# Final decoder stage: conv + GLU (2x channels gated down to 1x), then a
# stride-4 transposed conv producing num_sources * audio_channels channels.
# Unlike the inner decoders this layer has NO ReLU: its outputs are the
# separated waveforms, which must be able to take negative values.
d1 = keras.Sequential([
    WeightRescaling(keras.layers.Conv1D(2 * num_sources * audio_channels, kernel_size=3, strides=1,
                                padding='same', use_bias=False)),
    GatedLinearUnit(),
    WeightRescaling(keras.layers.Conv1DTranspose(num_sources * audio_channels, kernel_size=8, strides=4,
                        padding='same', use_bias=False))
])

In [9]:
# 441000 samples per example (10 s, assuming a 44.1 kHz sample rate) x 2 channels.
long = 10 * 44100
channels = 2

sample_inputs = keras.Input(shape=(long, channels))
print('sample_inputs.shape:', sample_inputs.shape)

sample_inputs.shape: (None, 441000, 2)


In [10]:
def _report_shape(label, tensor):
    """Print ``label`` followed by ``tensor.shape``, then return ``tensor`` unchanged."""
    print(label, tensor.shape)
    return tensor


# Run the six encoder stages in sequence, keeping every output for the skips.
e1_outputs = _report_shape('e1_output.shape:', e1(sample_inputs))
e2_outputs = _report_shape('e2_output.shape:', e2(e1_outputs))
e3_outputs = _report_shape('e3_output.shape:', e3(e2_outputs))
e4_outputs = _report_shape('e4_output.shape:', e4(e3_outputs))
e5_outputs = _report_shape('e5_output.shape:', e5(e4_outputs))
e6_outputs = _report_shape('e6_output.shape:', e6(e5_outputs))

e1_output.shape: (None, 110250, 64)
e2_output.shape: (None, 27563, 128)
e3_output.shape: (None, 6891, 256)
e4_output.shape: (None, 1723, 512)
e5_output.shape: (None, 431, 1024)
e6_output.shape: (None, 108, 2048)


In [11]:
# Bottleneck: a bidirectional LSTM over the e6 sequence, followed by a Dense
# projection back to e6's 2048 channels so it can be added to the e6 skip.
bottleneck_width = 2048

forward_layer = keras.layers.LSTM(bottleneck_width, return_sequences=True)
backward_layer = keras.layers.LSTM(bottleneck_width, go_backwards=True, return_sequences=True)
bi_lstm_layer = keras.layers.Bidirectional(layer=forward_layer, backward_layer=backward_layer)
bi_outputs = bi_lstm_layer(e6_outputs)
print('bi_outputs.shape:', bi_outputs.shape)

dense_layer = keras.layers.Dense(units=bottleneck_width)
dense_outputs = dense_layer(bi_outputs)
print('dense_outputs.shape:', dense_outputs.shape)

bi_outputs.shape: (None, 108, 4096)
dense_outputs.shape: (None, 108, 2048)


In [12]:
def center_trim(tensor, reference):
    """Center-crop ``tensor`` along axis 1 (time) to the axis-1 length of ``reference``.

    With length difference ``delta``, ``delta // 2`` steps are dropped from
    the start and the remainder from the end:

    - ``delta == 1`` keeps ``tensor[:, :-1]``;
    - ``delta == 2`` keeps ``tensor[:, 1:-1]``;
    - ``delta == 0`` returns ``tensor`` unchanged.

    Args:
        tensor: rank-3 tensor laid out ``(batch, time, channels)`` to crop.
        reference: tensor whose axis-1 length is the target length.

    Returns:
        ``tensor`` cropped so its axis-1 length equals ``reference``'s.

    Raises:
        ValueError: if both static lengths are known and ``tensor`` is
            shorter than ``reference``.
    """
    length, target = tensor.shape[1], reference.shape[1]
    if length is None or target is None:
        # Variable-length time axis: compute the crop from runtime shapes.
        length, target = tf.shape(tensor)[1], tf.shape(reference)[1]
        start = (length - target) // 2
        return tensor[:, start:start + target, :]
    delta = length - target
    if delta < 0:
        raise ValueError(
            'cannot trim length {} to longer reference length {}'.format(length, target))
    if delta == 0:
        return tensor
    start = delta // 2
    return tensor[:, start:start + target, :]


# The encoders round lengths up (e.g. 110250 -> 27563), so each 4x-upsampling
# decoder can overshoot its skip by a few steps. Center-trim before every skip
# add, rather than hard-coding crops that only fit one input length.
d6_outputs = d6(dense_outputs + e6_outputs)
print('d6_outputs.shape:', d6_outputs.shape)

d5_outputs = d5(center_trim(d6_outputs, e5_outputs) + e5_outputs)
print('d5_outputs.shape:', d5_outputs.shape)

d4_outputs = d4(center_trim(d5_outputs, e4_outputs) + e4_outputs)
print('d4_outputs.shape:', d4_outputs.shape)

d3_outputs = d3(center_trim(d4_outputs, e3_outputs) + e3_outputs)
print('d3_outputs.shape:', d3_outputs.shape)

d2_outputs = d2(center_trim(d3_outputs, e2_outputs) + e2_outputs)
print('d2_outputs.shape:', d2_outputs.shape)

d1_outputs = d1(center_trim(d2_outputs, e1_outputs) + e1_outputs)
# Match the model output length to the input length (no-op when already equal).
d1_outputs = center_trim(d1_outputs, sample_inputs)
print('d1_outputs.shape:', d1_outputs.shape)

d6_outputs.shape: (None, 432, 1024)
d5_outputs.shape: (None, 1724, 512)
d4_outputs.shape: (None, 6892, 256)
d3_outputs.shape: (None, 27564, 128)
d2_outputs.shape: (None, 110252, 64)
d1_outputs.shape: (None, 441000, 8)


In [13]:
## final: assemble the functional model from the symbolic input to the last decoder output
demucs = keras.Model(inputs=sample_inputs, outputs=d1_outputs)