In [2]:
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow.experimental.numpy as tfnp
import xarray as xr
import numpy as np



In [67]:
# Load data
train_ds = xr.open_dataset("/glade/scratch/lverhoef/gdl_toy_ds/train_data.nc")
validation_ds = xr.open_dataset("/glade/scratch/lverhoef/gdl_toy_ds/validation_data.nc").isel(p=slice(0, 1000))

In [82]:
train_ds

In [71]:
# Layer that implements 4 discrete rotational equivariance
# Note that this uses group convolution on Z^2 x C_4, where rotation is about the center of the domain.
class RotEquivConv2D(tf.keras.layers.Layer):
    def __init__(self, out_features, filt_shape, rot_axis=True, activation=tf.nn.relu, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.filt_shape = filt_shape
        self.rot_axis = rot_axis
        self.activation = activation

    def build(self, input_shape):  # Create the layer when it is first called
        self.in_features = input_shape[-1]
        self.filt_shape = tf.concat([
            self.filt_shape,  # Spatial dimensions
            [self.in_features, self.out_features]
        ], axis=0)
        self.filt_base = tf.Variable(
            tf.random.normal(self.filt_shape)  # Random initialization of filters
        )

    def call(self, inputs):  # Does the actual computation for each rotation
        if self.rot_axis:  # If we're already in Z^2 x C_4, convolve along each rotational layer
            return self.activation(tf.stack([
                    tf.nn.convolution(
                        tfnp.take(inputs, i, axis=-2),
                        self.filt_base)
                    for i in range(inputs.shape[-2])],
                axis=-2
            ))
        else:  # If we're not yet in the group domain, move to it.
            return self.activation(tf.stack([
                tf.nn.convolution(inputs, self.filt_base),
                tfnp.rot90(tf.nn.convolution(inputs, tfnp.rot90(self.filt_base, k=1)), k=4-1),
                tfnp.rot90(tf.nn.convolution(inputs, tfnp.rot90(self.filt_base, k=2)), k=4-2),
                tfnp.rot90(tf.nn.convolution(inputs, tfnp.rot90(self.filt_base, k=3)), k=4-3),
            ], axis=-2))

In [48]:
# 2D pooling layer that pools within each rotational dimension
class RotEquivPool2D(tf.keras.layers.Layer):
    def __init__(self, pool_size, pool_method=tf.keras.layers.MaxPool2D, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.pool_method = pool_method
        self.pool = self.pool_method(pool_size=self.pool_size)

    def call(self, inputs):
        return tf.stack(
            [self.pool(tfnp.take(inputs, k, axis=-2)) for k in range(inputs.shape[-2])],
            axis=-2
        )

In [50]:
# Rotational invariant pooling that pools across the rotational dimensions
class RotInvPool(tf.keras.layers.Layer):
    def __init__(self, pool_method='max', **kwargs):
        valid_methods = {'max', 'mean'}
        if pool_method not in valid_methods:
            raise ValueError(f'pool_method must be one of {valid_methods}')

        super().__init__(**kwargs)
        if pool_method == "max":
            self.pool_method = tf.math.reduce_max
        else:
            self.pool_method = tf.math.reduce_mean

    def call(self, inputs):
        return self.pool_method(inputs, axis=-2)

In [79]:
model = models.Sequential()
model.add(RotEquivConv2D(32, (3, 3), rot_axis=False, input_shape=(128, 128, 1)))
model.add(RotEquivPool2D((2, 2)))
model.add(RotEquivConv2D(32, (3, 3)))
model.add(RotEquivPool2D((2, 2)))
model.add(RotEquivConv2D(64, (3, 3)))
model.add(RotEquivPool2D((2, 2)))
model.add(RotEquivConv2D(64, (3, 3)))
model.add(RotEquivPool2D((2, 2)))
model.add(RotEquivConv2D(128, (3, 3)))
model.add(RotInvPool())
model.add(layers.Flatten())
model.add(layers.Dense(32))
model.add(layers.Dense(2))

In [80]:
model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mse']
)

In [81]:
model.summary()

Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_equiv_conv2d_20 (RotEqu  (126, 126, 126, 4, 32)   288       
 ivConv2D)                                                       
                                                                 
 rot_equiv_pool2d_22 (RotEqu  (126, 63, 63, 4, 32)     0         
 ivPool2D)                                                       
                                                                 
 rot_equiv_conv2d_21 (RotEqu  (126, 61, 61, 4, 32)     9216      
 ivConv2D)                                                       
                                                                 
 rot_equiv_pool2d_23 (RotEqu  (126, 30, 30, 4, 32)     0         
 ivPool2D)                                                       
                                                                 
 rot_equiv_conv2d_22 (RotEqu  (126, 28, 28, 4, 64)    

In [None]:
history = model.fit(
    train_ds['ellipse'], train_ds[['maj_len', 'min_len']],
    epochs=10, validation_data=(validation_ds['ellipse'], validation_ds[['maj_len', 'min_len']])
)