In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
import tensorflow as tf
import keras
from keras import layers

In [28]:
class ComplexConv2D(layers.Layer):
    def __init__(self, filters, kernel_size, strides=(1, 1), padding='same'):
        super(ComplexConv2D, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

        self.real_conv2D = layers.Conv2D(filters, kernel_size, strides = strides, padding = padding)
        self.complex_conv2D = layers.Conv2D(filters, kernel_size, strides = strides, padding = padding)

    def call(self, inputs):
        real_stft, img_stft = tf.split(inputs, 2, axis=-1)

        real_stft_real = self.real_conv2D(real_stft)
        img_stft_real = self.real_conv2D(img_stft)

        real_stft_img = self.complex_conv2D(real_stft)
        img_stft_img = self.complex_conv2D(img_stft)

        output_real = real_stft_real - img_stft_img
        output_img = real_stft_img + img_stft_real

        return tf.concat([output_real, output_img], axis=-1)
        

In [29]:
class ComplexBNPReLu(layers.Layer):
    def __init__(self):
        super(ComplexBNPReLu, self).__init__()
        self.real_bn = layers.BatchNormalization()
        self.img_bn = layers.BatchNormalization()
        self.real_prelu = layers.PReLU()
        self.img_prelu = layers.PReLU()


    def call(self, inputs):
        real, img = tf.split(inputs, 2, axis=-1)

        real = self.real_bn(real)
        img = self.img_bn(img)

        real = self.real_prelu(real)
        img = self.img_prelu(img)

        return tf.concat([real, img], axis=-1)

In [30]:
class ComplexEncodeBlock(layers.Layer):
    def __init__(self, filters, kernel_size=(3,3), strides=(1,1)):
        super(ComplexEncodeBlock, self).__init__()
        self.conv = ComplexConv2D(filters, kernel_size, strides=strides)
        self.bn_prelu = ComplexBNPReLu()

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.bn_prelu(x)
        return x

In [31]:
import tensorflow as tf
from tensorflow.keras import layers

class ComplexLSTM(layers.Layer):
    def __init__(self, units, **kwargs):
        super(ComplexLSTM, self).__init__(**kwargs)
        self.units = units
        
    def build(self, input_shape):
        # input_shape: (batch, time, height, width)
        height, width = input_shape[2], input_shape[3]
        
        # Option 1: Use ConvLSTM2D for spatial-temporal processing
        self.conv_lstm_rr = layers.ConvLSTM2D(self.units//4, (3, 3), padding='same', return_sequences=True)
        self.conv_lstm_ri = layers.ConvLSTM2D(self.units//4, (3, 3), padding='same', return_sequences=True)
        self.conv_lstm_ir = layers.ConvLSTM2D(self.units//4, (3, 3), padding='same', return_sequences=True)
        self.conv_lstm_ii = layers.ConvLSTM2D(self.units//4, (3, 3), padding='same', return_sequences=True)
        
        # Build the layers
        conv_input_shape = (input_shape[0], input_shape[1], height, width//2, 1)
        self.conv_lstm_rr.build(conv_input_shape)
        self.conv_lstm_ri.build(conv_input_shape)
        self.conv_lstm_ir.build(conv_input_shape)
        self.conv_lstm_ii.build(conv_input_shape)
        
        super(ComplexLSTM, self).build(input_shape)
    
    def call(self, inputs):
        # inputs shape: (batch, time, height, width)
        height, width = inputs.shape[2], inputs.shape[3]
        
        # Split into real and imaginary parts
        real_part = inputs[..., :width//2]  # (batch, time, height, width/2)
        imag_part = inputs[..., width//2:]  # (batch, time, height, width/2)
        
        # Add channel dimension for ConvLSTM2D
        real_part = tf.expand_dims(real_part, -1)  # (batch, time, height, width/2, 1)
        imag_part = tf.expand_dims(imag_part, -1)  # (batch, time, height, width/2, 1)
        
        # Apply complex ConvLSTM operations
        Frr = self.conv_lstm_rr(real_part)
        Fri = self.conv_lstm_ri(imag_part)
        Fir = self.conv_lstm_ir(real_part)
        Fii = self.conv_lstm_ii(imag_part)
        
        # Complex operations
        real_output = Frr - Fii
        imag_output = Fri + Fir
        
        # Concatenate along width dimension and remove channel dim
        real_output = tf.squeeze(real_output, -1)
        imag_output = tf.squeeze(imag_output, -1)
        output = tf.concat([real_output, imag_output], axis=-1)
        
        return output
    
    def compute_output_shape(self, input_shape):
        # Output shape maintains spatial dimensions
        return (input_shape[0], input_shape[1], input_shape[2], self.units//2)
    
    def get_config(self):
        config = super(ComplexLSTM, self).get_config()
        config.update({'units': self.units})
        return config

In [32]:
class ComplexDeconv2D(layers.Layer):
    def __init__(self, filters, kernel_size, strides=(2,2), padding='same'):
        super(ComplexDeconv2D, self).__init__()

        self.real_deconv = layers.Conv2DTranspose(filters, kernel_size, strides, padding)
        self.img_deconv = layers.Conv2DTranspose(filters, kernel_size, strides, padding)

    def call(self, inputs):
        real, img = tf.split(inputs, 2, axis=-1)

        real_stft_real = self.real_deconv(real)
        real_stft_img = self.img_deconv(real)

        img_stft_real = self.real_deconv(img)
        img_stft_img = self.img_deconv(img)

        output_real = real_stft_real - img_stft_img
        output_img = real_stft_img + img_stft_real

        return tf.concat([output_real, output_img], axis=-1)

In [51]:
import tensorflow as tf
from tensorflow.keras import layers

class ComplexDecoderBlock(layers.Layer):
    def __init__(self, filters, **kwargs):
        super(ComplexDecoderBlock, self).__init__(**kwargs)
        self.filters = filters
        
    def build(self, input_shape):
        # Dense layers for feature transformation
        self.dense1 = layers.Dense(self.filters * 2)
        self.dense2 = layers.Dense(self.filters)
        self.batch_norm1 = layers.BatchNormalization()
        self.batch_norm2 = layers.BatchNormalization()
        
        super(ComplexDecoderBlock, self).build(input_shape)
    
    def call(self, x, skip):
        print(f"Input x shape: {x.shape}")
        print(f"Input skip shape: {skip.shape}")
        
        # Simple approach: just handle the tensor shapes and concatenate
        batch_size = tf.shape(x)[0]
        
        # Upsample x spatially to match skip if needed
        if tf.shape(x)[2] != tf.shape(skip)[2] or tf.shape(x)[3] != tf.shape(skip)[3]:
            # Reshape for upsampling
            x_reshaped = tf.expand_dims(x, -1)  # Add channel dim
            x_reshaped = tf.reshape(x_reshaped, [-1, tf.shape(x)[2], tf.shape(x)[3], 1])
            
            # Upsample
            target_h, target_w = tf.shape(skip)[2], tf.shape(skip)[3]
            x_upsampled = tf.image.resize(x_reshaped, [target_h, target_w])
            
            # Reshape back
            x_upsampled = tf.reshape(x_upsampled, [batch_size, tf.shape(x)[1], target_h, target_w])
            x = tf.squeeze(x_upsampled, -1) if len(x_upsampled.shape) == 5 else x_upsampled
        
        # Match time dimensions
        min_time = tf.minimum(tf.shape(x)[1], tf.shape(skip)[1])
        x = x[:, :min_time, :, :]
        skip = skip[:, :min_time, :, :]
        
        # Concatenate along width dimension
        x_concat = tf.concat([x, skip], axis=-1)
        print(f"After concat: {x_concat.shape}")
        
        # Apply dense transformations along the last dimension
        x_out = self.dense1(x_concat)
        x_out = self.batch_norm1(x_out)
        x_out = tf.nn.relu(x_out)
        
        x_out = self.dense2(x_out)
        x_out = self.batch_norm2(x_out)
        x_out = tf.nn.relu(x_out)
        
        print(f"Final output: {x_out.shape}")
        return x_out
    
    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            x_shape, skip_shape = input_shape
            batch_size = x_shape[0]
            time_dim = min(x_shape[1], skip_shape[1]) if skip_shape[1] is not None else x_shape[1]
            height_dim = max(x_shape[2], skip_shape[2]) if skip_shape[2] is not None else x_shape[2]
        else:
            x_shape = input_shape
            batch_size = x_shape[0]
            time_dim = x_shape[1]
            height_dim = x_shape[2]
            
        return (batch_size, time_dim, height_dim, self.filters)
    
    def get_config(self):
        config = super(ComplexDecoderBlock, self).get_config()
        config.update({'filters': self.filters})
        return config

In [52]:
def dccrnModel(input_shape=(282, 256, 512)):
    inputs = layers.Input(shape=input_shape)
    x = inputs
    skips = []
    
    # Encoder
    for i, filters in enumerate([32, 64, 128, 256, 256]):
        x = ComplexEncodeBlock(filters)(x)
        if i < 4:  # Don't save the last layer as skip
            skips.append(x)
        print(f"After encoder block {i}: {x.shape}")
    
    # LSTM - make sure it preserves or predictably changes dimensions
    x = ComplexLSTM(units=256)(x)
    print(f"After LSTM: {x.shape}")
    
    # Decoder - reverse the filters and match with skips
    decoder_filters = [256, 128, 64, 32]
    for i, filters in enumerate(decoder_filters):
        if i < len(skips):
            skip = skips[-(i+1)]  # Get skip connection in reverse order
            print(f"Decoder {i}: x={x.shape}, skip={skip.shape}")
            x = ComplexDecoderBlock(filters)(x, skip)
        else:
            # No skip connection for this layer
            x = ComplexDecoderBlock(filters)(x, x)  # Pass x as both inputs
        print(f"After decoder block {i}: {x.shape}")
    
    output = ComplexConv2D(filters=1, kernel_size=(1,1))(x)
    
    model = tf.keras.Model(inputs=inputs, outputs=output)
    return model

In [53]:
model = dccrnModel()
model.summary()

After encoder block 0: (None, 282, 256, 64)
After encoder block 1: (None, 282, 256, 128)
After encoder block 2: (None, 282, 256, 256)
After encoder block 3: (None, 282, 256, 512)
After encoder block 4: (None, 282, 256, 512)
After LSTM: (None, 282, 256, 128)
Decoder 0: x=(None, 282, 256, 128), skip=(None, 282, 256, 512)
After decoder block 0: (None, 282, 256, 256)
Decoder 1: x=(None, 282, 256, 256), skip=(None, 282, 256, 256)
After decoder block 1: (None, 282, 256, 128)
Decoder 2: x=(None, 282, 256, 128), skip=(None, 282, 256, 128)
After decoder block 2: (None, 282, 256, 64)
Decoder 3: x=(None, 282, 256, 64), skip=(None, 282, 256, 64)
After decoder block 3: (None, 282, 256, 32)
