Load libs and utilities. 

In [1]:
!pip install -U -q PyDrive

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

from google.colab import drive
drive.mount('/content/drive')
%cd "drive/MyDrive/Projects/Fourier"
!pip install import-ipynb
import import_ipynb

Mounted at /content/drive
/content/drive/MyDrive/Projects/Fourier
Collecting import-ipynb
  Downloading https://files.pythonhosted.org/packages/63/35/495e0021bfdcc924c7cdec4e9fbb87c88dd03b9b9b22419444dc370c8a45/import-ipynb-0.1.3.tar.gz
Building wheels for collected packages: import-ipynb
  Building wheel for import-ipynb (setup.py) ... [?25l[?25hdone
  Created wheel for import-ipynb: filename=import_ipynb-0.1.3-cp37-none-any.whl size=2976 sha256=be3f9d46cb95832a7fcb579be3d210152d056eb7f8c24e67b0c0e8b1aab4b407
  Stored in directory: /root/.cache/pip/wheels/b4/7b/e9/a3a6e496115dffdb4e3085d0ae39ffe8a814eacc44bbf494b5
Successfully built import-ipynb
Installing collected packages: import-ipynb
Successfully installed import-ipynb-0.1.3


In [2]:
import os
import tensorflow as tf
print("Tensorflow version: " + tf.__version__)

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
from pathlib import Path

from utils import *

Tensorflow version: 2.4.1


In [3]:
class UnitGaussianNormalizer:
    def __init__(self, x, eps=0.00001):
        super(UnitGaussianNormalizer, self).__init__()
        self.mean = tf.math.reduce_mean(x, 0)
        self.std = tf.math.reduce_std(x, 0)
        self.eps = eps

    def encode(self, x):
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x):
        std = self.std + self.eps
        mean = self.mean
        x = (x * std) + mean
        return x

In [4]:
PROJECT_PATH = Path(os.path.abspath('')).parent.parent.resolve().__str__()
TRAIN_PATH = PROJECT_PATH + '/Datasets/Fourier/piececonst_r241_N1024_smooth1.mat'
TEST_PATH = PROJECT_PATH + '/Datasets/Fourier/piececonst_r241_N1024_smooth2.mat'

N_TRAIN = 1000
W = 49 #width
FTS = 32 #features
R = 5 #refinement
MODES = 12

# ...
try:
  if DATA_IS_LOADED:
    print("Not reloading data!")
except:
  reader = MatReader()
  if reader.is_not_loaded():
    reader.load_file(TRAIN_PATH)

DATA_IS_LOADED = True
# ...

x_train = reader.read_field('coeff')[:N_TRAIN,::R,::R]
y_train = reader.read_field('sol')[:N_TRAIN,::R,::R]

S_ = x_train.shape[1]
grids = []
grids.append(np.linspace(0, 1, S_))
grids.append(np.linspace(0, 1, S_))
grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T
grid = grid.reshape(1,S_,S_,2)

print(x_train.shape)

x_train = tf.convert_to_tensor(x_train, dtype=tf.float32)
y_train = tf.convert_to_tensor(y_train, dtype=tf.float32)
grid = tf.convert_to_tensor(grid, dtype=tf.float32)
x_train = tf.expand_dims(x_train, axis=3)
grid = tf.repeat(grid, repeats = N_TRAIN, axis = 0)
x_train = tf.concat([x_train, grid], axis=3)
y_train = tf.expand_dims(y_train, axis=3)

x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)

y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

print("x_train dims: " + str(x_train.shape))
print("y_train dims: " + str(y_train.shape))

(1000, 49, 49)
x_train dims: (1000, 49, 49, 3)
y_train dims: (1000, 49, 49, 1)


In [5]:
class FourierLayer(layers.Layer):
    def __init__(self):
        super(FourierLayer, self).__init__()  
        self.weight_fft1 = tf.Variable(tf.random.uniform([FTS, FTS, MODES, MODES], minval=0, maxval=1),name="Wfft1", trainable=True)
        self.weight_fft2 = tf.Variable(tf.random.uniform([FTS, FTS, MODES, MODES], minval=0, maxval=1),name="Wfft2", trainable=True)

    def call(self, input, training=True):
        weight_fft_complex = tf.complex(self.weight_fft1, self.weight_fft2)
        x = input
        x = keras.layers.Lambda(lambda v: tf.signal.rfft2d(v, tf.constant([49, 49])))(x)
        x = x[:,:,:MODES, :MODES]
        x = keras.layers.Lambda(lambda v: tf.einsum('ioxy,bixy->boxy', weight_fft_complex, v))(x)
        x = keras.layers.Lambda(lambda v: tf.signal.irfft2d(v, tf.constant([49, 49])))(x)
        return x

In [6]:
class FourierUnit(layers.Layer):
    def __init__(self):
        super(FourierUnit, self).__init__()
        self.W = tf.keras.layers.Conv1D(W, 1)
        self.fourier = FourierLayer()        
        self.add = tf.keras.layers.Add()
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, input, training=True):
        x = input
        x1 = self.fourier(x)
        x2 = self.W(x)
        x = self.add([x1, x2])
        x = self.bn(x)
        return x

In [7]:
class MyModel(keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc0 = tf.keras.layers.Dense(FTS)
        self.perm_pre = tf.keras.layers.Permute((3, 1, 2))
        
        self.fourier_unit_1 = FourierUnit()
        self.relu_1 = tf.keras.layers.ReLU()

        self.fourier_unit_2 = FourierUnit()

        self.relu = tf.keras.layers.ReLU()
        self.perm_post = tf.keras.layers.Permute((2, 3, 1))
        self.fc1 = tf.keras.layers.Dense(128)
        self.relu2 = tf.keras.layers.ReLU()
        self.fc2 = tf.keras.layers.Dense(1)

    def call(self, input):
        x = self.fc0(input)
        x = self.perm_pre(x)
        
        x = self.fourier_unit_1(x)
        x = self.relu_1(x)

        x = self.fourier_unit_2(x)

        x = self.perm_post(x)
        x = self.fc1(x)
        x = self.relu2(x)
        x = self.fc2(x)
        return x

    def model(self):
        x = keras.Input(shape=(W, W, 3))
        return keras.Model(inputs=[x], outputs=self.call(x))

In [8]:
model = MyModel()
mse = tf.keras.losses.MeanSquaredError()
model.compile(
    loss=mse,
    optimizer=keras.optimizers.Adam(lr=3e-4),
    metrics=[tf.keras.metrics.RootMeanSquaredError()],
)
model.fit(x_train, y_train, batch_size=64, epochs=2, verbose=2)
model.model().summary()

Epoch 1/2
16/16 - 36s - loss: 1.0500 - root_mean_squared_error: 1.0247
Epoch 2/2
16/16 - 1s - loss: 0.7540 - root_mean_squared_error: 0.8683
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 49, 49, 3)]       0         
_________________________________________________________________
dense (Dense)                (None, 49, 49, 32)        128       
_________________________________________________________________
permute (Permute)            (None, 32, 49, 49)        0         
_________________________________________________________________
fourier_unit (FourierUnit)   (None, 32, 49, 49)        297558    
_________________________________________________________________
re_lu (ReLU)                 (None, 32, 49, 49)        0         
_________________________________________________________________
fourier_unit_1 (FourierUnit) (None, 32, 49, 49)     