In [7]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [8]:
import tensorflow as tf
import keras
from keras import layers

In [35]:
class complexConv2D(layers.Layer):
    def __init__(self, filters, kernel_size, strides=(1, 1), padding='same'):
        super(complexConv2D, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

        self.real_conv2D = layers.Conv2D(filters, kernel_size, strides = strides, padding = padding)
        self.complex_conv2D = layers.Conv2D(filters, kernel_size, strides = strides, padding = padding)

    def call(self, input_stft):
        real_stft, img_stft = tf.split(input_stft, axis=-1)

        real_stft_real = self.real_conv2D(real_stft)
        real_stft_img = self.complex_conv2D(real_stft)

        img_stft_real = self.real_conv2D(img_stft)
        img_stft_img = self.complex_conv2D(img_stft)

        output_real = real_stft_real - img_stft_img
        output_img = real_stft_img + img_stft_real

        return tf.stack([output_real, output_img], axis=-1)
        

In [36]:
class complexBN_PReLu(layers.Layer):
    def __init__(self):
        super(complexBN_PReLu, self).__init__()
        self.real_bn = layers.BatchNormalization()
        self.img_bn = layers.BatchNormalization()
        self.real_prelu = layers.PReLU()
        self.img_prelu = layers.PReLU()


    def call(self, inputs):
        real, img = tf.split(inputs, 2, axis=-1)

        real = self.real_bn(real)
        img = self.img_bn(img)

        real = self.real_prelu(real)
        img = self.img_prelu(img)

        return tf.stack([real, img], axis=-1)

In [37]:
class complexEncodeBlock(layers.Layer):
    def __init__(self, filters, kernel_size=(3,3), strides=(1,1)):
        super(complexEncodeBlock, self).__init__()
        self.conv = complexConv2D(filters, kernel_size, strides=strides)
        self.bn_prelu = complexBN_PReLu()

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.bn_prelu(x)
        return x

In [38]:
class complexLSTM(layers.Layer):
    def __init__(self, unit):
        super(complexLSTM, self).__init__()
        self.real_lstm = layers.LSTM(unit, return_sequences=True)
        self.img_lstm = layers.LSTM(unit, return_sequences=True)

    def call(self, inputs):
        B, T, F, C = inputs.shape
        r,img = tf.split(inputs, 2, axis=-1)

        r = tf.reshape(r, (B, T, -1))
        img = tf.reshape(img, (B, T, -1))

        Frr = self.real_lstm(r)
        Fir = self.real_lstm(img)
        Fri = self.img_lstm(r)
        Fii = self.img_lstm(img)

        out_real = Frr - Fii
        out_img = Fri + Fir
        
        F_dim = out_real.shape[-1]
        out_real = tf.reshape(out_real, (B, T, F_dim//2, 1))
        out_img = tf.reshape(out_img, (B, T, F_dim//2, 1))

        return tf.stack([out_real, out_img], axis=-1)

In [39]:
class complexDeconv2D(layers.Layer):
    def __init__(self, filters, kernel_size, strides=(2,2), padding='same'):
        super(complexDeconv2D, self).__init__()

        self.real_deconv = layers.Conv2DTranspose(filters, kernel_size, strides, padding)
        self.img_deconv = layers.Conv2DTranspose(filters, kernel_size, strides, padding)

    def call(self, inputs):
        real, img = tf.spkit(inuts, 2, axis=-1)

        real_stft_real = self.real_deconv(real)
        real_stft_img = self.img_deconv(real)

        img_stft_real = self.real_deconv(img)
        img_stft_img = self.img_deconv(img)

        output_real = real_stft_real - img_stft_img
        output_img = real_stft_img + img_stft_real

        return tf.stack([output_real, output_img], axis=-1)

In [40]:
class complexDecoderBlock(layers.Layer):
    def __init__(self, filters, kernel_size=(3,3), strides=(2,2)):
        super(complexDecoderBlock, self).__init__()
        self.deconv = complexDeconv2D(filters, kernel_size, strides)
        self.batchNormActivation = complexBN_PReLu()

    def call(self, x, skip):
        x = tf.stack([x, skip], axis=-1)
        x = self.deconv(x)
        x = self.batchNormActivation(x)
        return x

In [41]:
def dccrnModel(input_shape=(282, 256, 2)):
    inputs = tf.keras.Input(shape = input_shape)
    x = inputs
    skips = []

    for filters in [32, 64, 128, 256, 256]:
        x = complexEncodeBlock(filters)(x)
        skips.append(x)

    x = complexLSTM(unit = 256)(x)

    for i, filters in enumerate([256, 256, 128, 64, 32]):
        skip = skips[-(i+1)]
        x = complexDecoderBlock(filters)(x, skip)

    output = complexConv2D(filters = 1, kernel_size=(1,1))(x)
    return tf.keras.Model(inputs=inputs, outputs=output)

In [42]:
model = dccrnModel()
model.summary()

1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Exception encountered when calling complexConv2D.call().

[1mMissing required positional argument[0m

Arguments received by complexConv2D.call():
  • input_stft=tf.Tensor(shape=(None, 282, 256, 2), dtype=float32)''


TypeError: Exception encountered when calling complexEncodeBlock.call().

[1mCould not automatically infer the output shape / dtype of 'complex_encode_block_2' (of type complexEncodeBlock). Either the `complexEncodeBlock.call()` method is incorrect, or you need to implement the `complexEncodeBlock.compute_output_spec() / compute_output_shape()` method. Error encountered:

Exception encountered when calling complexConv2D.call().

[1mMissing required positional argument[0m

Arguments received by complexConv2D.call():
  • input_stft=tf.Tensor(shape=(None, 282, 256, 2), dtype=float32)[0m

Arguments received by complexEncodeBlock.call():
  • args=('<KerasTensor shape=(None, 282, 256, 2), dtype=float32, sparse=False, name=keras_tensor_2>',)
  • kwargs=<class 'inspect._empty'>