In [2]:
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow.experimental.numpy as tfnp
import xarray as xr
import numpy as np



In [2]:
# Load data
train_ds = xr.open_dataset("/glade/scratch/lverhoef/gdl_toy_ds/train_data.nc")
validation_ds = xr.open_dataset("/glade/scratch/lverhoef/gdl_toy_ds/validation_data.nc")

In [13]:
# Build a layer that implements 4 discrete rotational invariance
class RotEquivConv2D(tf.keras.layers.Layer):
    def __init__(self, out_features, filt_shape, rot_axis=True, activation=tf.nn.relu, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.filt_shape = filt_shape
        self.rot_axis = rot_axis
        self.activation = activation

    def build(self, input_shape):  # Create the layer when it is first called
        self.in_features = input_shape[-1]
        self.filt_shape = tf.concat([
            self.filt_shape,  # Spatial dimensions
            [self.in_features, self.out_features]
        ], axis=0)
        self.filt_base = tf.Variable(
            tf.random.normal(self.filt_shape)  # Random initialization of filters
        )

    def call(self, inputs):  # Does the actual computation for each rotation
        if self.rot_axis:  # If the rotation axis exists, concatenate along it
            return self.activation(tf.stack([
                    tf.nn.convolution(
                        tfnp.take(inputs, i, axis=-2),
                        tfnp.rot90(self.filt_base, k=i))
                    for i in range(inputs.shape[-2])],
                axis=-2
            ))
        else:
            return self.activation(tf.stack([  # Otherwise, create it then concatenate
                tf.nn.convolution(inputs, self.filt_base),
                tf.nn.convolution(inputs, tfnp.rot90(self.filt_base, k=1)),
                tf.nn.convolution(inputs, tfnp.rot90(self.filt_base, k=2)),
                tf.nn.convolution(inputs, tfnp.rot90(self.filt_base, k=3)),
            ], axis=-2))

In [34]:
# Build a 2D pooling layer that pools within each rotational dimension
class RotEquivPool2D(tf.keras.layers.Layer):
    def __init__(self, pool_size, pool_method=tf.keras.layers.MaxPool2D, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.pool_method = pool_method
        self.pool = self.pool_method(pool_size=self.pool_size)

    def call(self, inputs):
        return tf.stack(
            [self.pool(tfnp.take(inputs, k, axis=-2)) for k in range(inputs.shape[-2])],
            axis=-2
        )

In [41]:
model = models.Sequential()
model.add(RotEquivConv2D(8, (3, 3), rot_axis=False, input_shape=(128, 128, 1)))
model.add(RotEquivPool2D((2, 2)))
model.add(RotEquivConv2D(16, (3, 3)))
model.add(RotEquivPool2D((2, 2)))
model.add(RotEquivConv2D(32, (3, 3)))
model.add(RotEquivPool2D((2, 2)))

NameError: name 'RotEquivConv2D' is not defined

In [21]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_inv_conv2d_4 (RotInvCon  (None, 126, 126, 4, 8)   72        
 v2D)                                                            
                                                                 
 rot_inv_conv2d_5 (RotInvCon  (None, 124, 124, 4, 16)  1152      
 v2D)                                                            
                                                                 
 rot_inv_conv2d_6 (RotInvCon  (None, 122, 122, 4, 32)  4608      
 v2D)                                                            
                                                                 
Total params: 5,832
Trainable params: 5,832
Non-trainable params: 0
_________________________________________________________________


In [37]:
tfnp.take(B, 0, axis=-1)

<tf.Tensor: shape=(3, 3, 3, 4), dtype=float32, numpy=
array([[[[0.        , 0.        , 0.43772304, 0.        ],
         [1.3640386 , 0.        , 1.7690954 , 0.        ],
         [0.        , 2.5217226 , 0.        , 1.5063541 ]],

        [[1.4709724 , 0.        , 2.8582454 , 0.        ],
         [0.        , 0.        , 0.        , 1.1099697 ],
         [0.        , 1.4837714 , 0.13736852, 2.8603058 ]],

        [[0.        , 5.5206537 , 0.        , 4.1514945 ],
         [1.1968007 , 0.        , 3.8705645 , 0.        ],
         [0.66617167, 1.6762189 , 3.0199127 , 2.7503014 ]]],


       [[[2.2047439 , 0.5287634 , 2.4941406 , 1.5711395 ],
         [0.        , 3.7656643 , 0.        , 1.0882193 ],
         [0.        , 0.        , 0.39570844, 1.140449  ]],

        [[0.        , 0.        , 0.        , 2.1528885 ],
         [0.        , 1.002754  , 0.        , 0.5949199 ],
         [0.        , 0.6840533 , 1.1966977 , 1.9588431 ]],

        [[0.        , 0.        , 4.601345  , 0. 

In [40]:
pool = RotEquivPool2D((2, 2))
tfnp.take(pool(B), 0, -1)

<tf.Tensor: shape=(3, 1, 1, 4), dtype=float32, numpy=
array([[[[1.4709724, 0.       , 2.8582454, 1.1099697]]],


       [[[2.2047439, 3.7656643, 2.4941406, 2.1528885]]],


       [[[0.631684 , 1.3524477, 1.3281293, 0.       ]]]], dtype=float32)>