In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import xarray as xr
import numpy as np

2022-06-07 15:28:21.336637: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Load data
train_ds = xr.open_dataset("/glade/scratch/lverhoef/gdl_toy_ds/train_data.nc")
validation_ds = xr.open_dataset("/glade/scratch/lverhoef/gdl_toy_ds/validation_data.nc")

In [12]:
# Build a layer that implements 4 discrete rotational invariance
class RotInvConv2D(tf.keras.layers.Layer):
    def __init__(self, out_features, filt_shape, rot_axis=True, activation=tf.nn.relu, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.filt_shape = filt_shape
        self.rot_axis = rot_axis
        self.activation = activation

    def build(self, input_shape):  # Create the layer when it is first called
        self.in_channels = input_shape[-1]
        if self.rot_axis:
            self.filt_shape_append = [1, self.in_channels, self.out_features]  # Rotation dimension, then number of channels in, number of base channels out
        else:
            self.filt_shape_append = [self.in_channels, self.out_features]  # Omitting rotation dimension if it has not yet been created
        self.filt_shape = tf.concat([
            self.filt_shape,  # Spatial dimensions
            self.filt_shape_append
        ], axis=0)
        self.filt_base = tf.Variable(
            tf.random.normal(self.filt_shape)  # Random initialization of filters
        )

    def call(self, inputs):  # Does the actual computation for each rotation
        self.rot90 = tf.experimental.numpy.rot90  # Use TF implementation of numpy's rot90
        if self.rot_axis:  # If the rotation axis exists, concatenate along it
            return self.activation(tf.concat([
                tf.nn.convolution(inputs, self.filt_base),
                tf.nn.convolution(inputs, self.rot90(self.filt_base, k=1)),
                tf.nn.convolution(inputs, self.rot90(self.filt_base, k=2)),
                tf.nn.convolution(inputs, self.rot90(self.filt_base, k=3)),
            ], axis=-2))
        else:
            return self.activation(tf.stack([  # Otherwise, create it then concatenate
                tf.nn.convolution(inputs, self.filt_base),
                tf.nn.convolution(inputs, self.rot90(self.filt_base, k=1)),
                tf.nn.convolution(inputs, self.rot90(self.filt_base, k=2)),
                tf.nn.convolution(inputs, self.rot90(self.filt_base, k=3)),
            ], axis=-2))

In [9]:
model = models.Sequential()
model.add(RotInvConv2D(8, (3, 3), rot_axis=False, input_shape=(128, 128, 1)))
model.add(RotInvConv2D(8, (3, 3)))

In [11]:
model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_inv_conv2d_3 (RotInvCon  (None, 126, 126, 4, 8)   72        
 v2D)                                                            
                                                                 
 rot_inv_conv2d_4 (RotInvCon  (None, 124, 124, 16, 8)  576       
 v2D)                                                            
                                                                 
Total params: 648
Trainable params: 648
Non-trainable params: 0
_________________________________________________________________
