In [2]:
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow.experimental.numpy as tfnp
import xarray as xr
import numpy as np



In [2]:
# Load the toy train/validation datasets (netCDF, opened lazily by xarray).
# The directory lives in one constant so the notebook can be pointed at another
# copy of the data without editing every path.
DATA_DIR = "/glade/scratch/lverhoef/gdl_toy_ds"
train_ds = xr.open_dataset(f"{DATA_DIR}/train_data.nc")
validation_ds = xr.open_dataset(f"{DATA_DIR}/validation_data.nc")

In [13]:
# Build a layer that applies one kernel at 4 discrete rotations (90-degree steps)
class RotInvConv2D(tf.keras.layers.Layer):
    """Convolution that applies one shared kernel at four 90-degree rotations.

    The kernel is rotated by 0, 1, 2 and 3 quarter turns in its two spatial
    axes. The four convolution responses are stacked on a new "rotation" axis
    at position -2, so the output has shape
    ``(batch, H', W', 4, out_features)``. ``tf.nn.convolution`` is called with
    its default ``"VALID"`` padding, so ``H' = H - filt_shape[0] + 1`` (and
    likewise for ``W'``).

    Despite the class name, the output is rotation-*equivariant*, not
    invariant. Rotating the input by a quarter turn rotates every output map
    and cyclically shifts the rotation axis. Invariance needs e.g. pooling
    over that axis downstream.

    Args:
        out_features: Number of output channels per orientation.
        filt_shape: Spatial kernel size, e.g. ``(3, 3)``.
        rot_axis: ``False`` for plain ``(batch, H, W, C)`` inputs (typically
            the first layer). ``True`` when the input already carries a
            rotation axis of size 4 at position -2, e.g. the output of another
            ``RotInvConv2D``. Orientation ``k`` of the input is then convolved
            only with the kernel rotated ``k`` quarter turns; orientations are
            not mixed.
        activation: Callable, Keras activation name, or ``None`` (identity),
            applied to the stacked output.
        kernel_initializer: Initializer for the shared kernel. Defaults to a
            standard normal (mean 0, stddev 1), matching the original
            behaviour. Note that this default is not scaled by fan-in.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``input_shape``).
    """

    N_ROTATIONS = 4  # quarter turns: rot90 with k = 0, 1, 2, 3

    def __init__(self, out_features, filt_shape, rot_axis=True, activation=tf.nn.relu,
                 kernel_initializer=None, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features
        # Keep the spatial kernel size as given; the full 4-D kernel shape is
        # stored separately in build() so this argument survives for get_config.
        self.filt_shape = tuple(filt_shape)
        self.rot_axis = rot_axis
        self.activation = tf.keras.activations.get(activation)
        if kernel_initializer is None:
            kernel_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)

    def build(self, input_shape):
        """Create the shared kernel once the input channel count is known.

        Raises:
            ValueError: if the channel dimension is undefined, or if
                ``rot_axis`` is set and axis -2 is not of size 4.
        """
        input_shape = tf.TensorShape(input_shape)
        if self.rot_axis and input_shape[-2] is not None and input_shape[-2] != self.N_ROTATIONS:
            raise ValueError(
                f"RotInvConv2D with rot_axis=True expects a rotation axis of size "
                f"{self.N_ROTATIONS} at position -2, got input shape {input_shape}."
            )
        self.in_features = input_shape[-1]
        if self.in_features is None:
            raise ValueError(
                f"RotInvConv2D needs a defined channel dimension, got input shape {input_shape}."
            )
        # Kernel layout expected by tf.nn.convolution: (h, w, in, out).
        self.kernel_shape = self.filt_shape + (self.in_features, self.out_features)
        self.filt_base = self.add_weight(
            name="filt_base",
            shape=self.kernel_shape,
            initializer=self.kernel_initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Convolve with the kernel at each of the four rotations and stack the results."""
        responses = []
        for k in range(self.N_ROTATIONS):
            # rot90 rotates the first two (spatial) axes of the (h, w, in, out) kernel.
            kernel = tfnp.rot90(self.filt_base, k=k)
            # With a rotation axis, input orientation k pairs with the kernel
            # rotated k times; otherwise every rotation sees the whole input.
            features = inputs[..., k, :] if self.rot_axis else inputs
            responses.append(tf.nn.convolution(features, kernel))
        return self.activation(tf.stack(responses, axis=-2))

    def get_config(self):
        """Return the constructor arguments so the layer can be saved/cloned."""
        config = super().get_config()
        config.update({
            "out_features": self.out_features,
            "filt_shape": self.filt_shape,
            "rot_axis": self.rot_axis,
            "activation": tf.keras.activations.serialize(self.activation),
            "kernel_initializer": tf.keras.initializers.serialize(self.kernel_initializer),
        })
        return config

In [17]:
# Three stacked layers: the first turns the single-channel 128x128 input into
# features with a rotation axis; the next two consume that axis.
model = models.Sequential([
    RotInvConv2D(8, (3, 3), rot_axis=False, input_shape=(128, 128, 1)),
    RotInvConv2D(16, (3, 3)),
    RotInvConv2D(32, (3, 3)),
])

In [18]:
# Show each layer's output shape and parameter count.
# NOTE(review): the saved output reports rotation axes of size 16 and 64 for the
# 2nd/3rd layers, but RotInvConv2D stacks exactly 4 responses (the toy cells give
# 4). The output looks stale from an earlier version of the class; re-run to refresh.
model.summary()

Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_inv_conv2d_8 (RotInvCon  (None, 126, 126, 4, 8)   72        
 v2D)                                                            
                                                                 
 rot_inv_conv2d_9 (RotInvCon  (None, 124, 124, 16, 16)  1152     
 v2D)                                                            
                                                                 
 rot_inv_conv2d_10 (RotInvCo  (None, 122, 122, 64, 32)  4608     
 nv2D)                                                           
                                                                 
Total params: 5,832
Trainable params: 5,832
Non-trainable params: 0
_________________________________________________________________


In [16]:
# Small toy batch for checking layer output shapes by hand:
# 3 examples, 4x4 spatial grid, 2 channels (channels-last, as the layer expects).
# Defined here so the shape-check cells run on a fresh kernel.
A = tf.random.normal((3, 4, 4, 2), seed=0)
A.shape

TensorShape([3, 4, 4, 2])

In [14]:
# First layer: the input has no rotation axis yet, so the layer creates one.
init_layer = RotInvConv2D(out_features=8, filt_shape=(2, 2), rot_axis=False)
B = init_layer(A)

In [15]:
# 4x4 input with a 2x2 kernel and VALID padding -> 3x3 maps; 4 rotations; 8 features.
assert B.shape == (3, 3, 3, 4, 8), B.shape
B.shape

TensorShape([3, 3, 3, 4, 8])

In [17]:
# Hidden layer: consumes the rotation axis produced by the previous layer.
next_layer = RotInvConv2D(out_features=16, filt_shape=(2, 2), rot_axis=True)
C = next_layer(B)

In [19]:
# 3x3 maps with a 2x2 kernel -> 2x2 maps; rotation axis stays at 4; 16 features.
assert C.shape == (3, 2, 2, 4, 16), C.shape
C.shape

TensorShape([3, 2, 2, 4, 16])