In [None]:
import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
import tensorflow as tf
import cv2
import numpy as np
from diametery.fiber import Fiber, Image, img_size
from tensorflow.keras.optimizers import Adam
from tensorflow.data import Dataset

In [None]:

def get_model(img_size, num_classes, kernel_size=1):
    """Build a residual encoder-decoder (U-Net-style) fully-convolutional Keras model.

    Args:
        img_size: spatial input size ``(height, width)``. Any sequence works
            (list or tuple); a single input channel is appended.
        num_classes: number of output channels of the final per-pixel Conv2D.
        kernel_size: kernel size of the SeparableConv2D / Conv2DTranspose layers
            inside the encoder and decoder blocks. Defaults to 1, which matches
            the original architecture.

    Returns:
        An uncompiled ``keras.Model`` mapping ``(batch, H, W, 1)`` inputs to
        ``(batch, H', W', num_classes)`` tanh-activated outputs.

    NOTE(review): the encoder halves H and W three times (rounding up, "same"
    padding) and the decoder doubles them three times. So ``H' == H`` and
    ``W' == W`` only when both are multiples of 8.
    """
    # tuple() accepts list or tuple sizes alike. `img_size + [1,]` raised TypeError
    # for a tuple, and added 1 element-wise for a numpy array.
    inputs = keras.Input(shape=tuple(img_size) + (1,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 1, strides=1, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth; each halves H and W.
    for filters in [64, 128, 256]:
        # On the first pass this ReLU repeats the entry block's ReLU (idempotent, harmless).
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, kernel_size, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, kernel_size, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual to the new depth and resolution with a strided 1x1 conv
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    # The 256/128/64 blocks double H and W; the final 32-filter block keeps resolution.
    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, kernel_size, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, kernel_size, padding="same")(x)
        x = layers.BatchNormalization()(x)

        # Project residual (upsampled alongside x so the shapes match for the add)
        if filters != 32:
            x = layers.UpSampling2D(2)(x)
            residual = layers.UpSampling2D(2)(previous_block_activation)
        else:
            residual = previous_block_activation
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Per-pixel output layer; tanh bounds every output channel to [-1, 1]
    outputs = layers.Conv2D(num_classes, 1, activation="tanh", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()


In [None]:
# Build the model. The "classes" are really two regression output channels,
# compared against y_true[..., 0:2] by the loss below.
num_classes = 2
model = get_model(img_size, num_classes=num_classes)
model.summary()

In [None]:
# Save (and display) a diagram of the architecture with tensor shapes.
plot_model(
    model,
    to_file='model.png',
    show_shapes=True,
    show_layer_names=True,
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

# Shared colour scaling: maps values in [-1, 1] (the model's tanh range) onto [0, 1].
nrm = Normalize(-1, 1)


In [None]:
# Draw one random synthetic sample and turn it into a float32 single-channel image.
# The pixels are divided by 255 (presumably 8-bit input; TODO confirm).
img = Image.create()
im = (img.render_image().astype(np.float32) / 255)[..., np.newaxis]
im.shape

In [None]:
# Render the regression target for the sample. Channels 0-1 hold the field
# components and channel 2 the per-pixel weight, as read by weighted_l2.
field_and_weights = img.render_field_and_weights()
# Pass raw values and let `norm` map [-1, 1] onto the colormap. Pre-normalising with
# nrm as well would squash the data into the upper half of the colour range.
plt.imshow(field_and_weights[:, :, 1], norm=nrm)
plt.colorbar();

In [None]:
# Forward pass on a batch containing just the sample image.
pred = model(im[np.newaxis, ...])

In [None]:
# Show the two predicted field components as red/green over an empty blue channel.
# Only the field channels are mapped from [-1, 1] to [0, 1]. Normalising the zero
# blue channel too would turn it into a constant 0.5 tint.
pred_field = np.asarray(pred[0, :, :, :2])
im_b = np.zeros(pred_field.shape[:2], dtype=float)
im_rgb = np.dstack([np.asarray(nrm(pred_field)), im_b])
# imshow ignores `norm` for RGB input, so the array is already scaled to [0, 1].
plt.imshow(im_rgb);

In [None]:
from diametery.skeleton import get_total_flux

# Total flux of the predicted field for the sample image.
total_flux = get_total_flux(pred[0])
plt.imshow(total_flux)

In [None]:
def weighted_l2(y_true, y_pred):
    """Per-sample, pixel-weighted Euclidean distance between target and predicted fields.

    Args:
        y_true: rank-4 batch ``(batch, H, W, C)``. Channels 0-1 are the target
            field components and channel 2 is a per-pixel weight.
        y_pred: predicted field, expected to have the shape of ``y_true[..., 0:2]``.

    Returns:
        Tensor of shape ``(batch,)``: the mean over H and W of
        ``weight * ||target - prediction||``.
    """
    # Cast so that (e.g.) float64 numpy targets combine with float32 predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    f_true = y_true[:, :, :, 0:2]
    w = y_true[:, :, :, 2]
    diff = tf.subtract(f_true, y_pred, name='diff')
    # Same value as reduce_euclidean_norm over the channel axis. That op's gradient is
    # x / ||x||, which is NaN wherever the residual is exactly zero. Taking sqrt only
    # where the squared norm is positive (the "double where" trick) keeps gradients finite.
    sq_norm = tf.reduce_sum(tf.square(diff), axis=-1)
    positive = sq_norm > 0
    safe_sq_norm = tf.where(positive, sq_norm, tf.ones_like(sq_norm))
    l2 = tf.where(positive, tf.sqrt(safe_sq_norm), tf.zeros_like(sq_norm))
    pixel_loss = tf.multiply(w, l2)
    return tf.reduce_mean(pixel_loss, axis=[-2, -1])

In [None]:
# Evaluate the loss on the current prediction for the sample, as a batch of one.
y_true = img.render_field_and_weights()[np.newaxis, ...]
loss = weighted_l2(y_true, pred)
loss.shape

In [None]:
# Adam with a small learning rate, optimising the weighted field loss.
adam = Adam(learning_rate=1e-4)
model.compile(loss=weighted_l2, optimizer=adam)

In [None]:
def epoch_gen(batch_size=8, n_batches=10):
    """Generate one epoch of freshly rendered synthetic batches.

    Yields ``n_batches`` pairs ``(X, Y)``:
    - ``X`` stacks ``batch_size`` rendered images. Each is cast to float32,
      divided by 255, and given a trailing channel axis.
    - ``Y`` stacks the matching ``render_field_and_weights()`` targets.

    A new ``Image.create()`` sample is drawn for every element.
    """
    def render_sample():
        sample = Image.create()
        image = sample.render_image().astype(np.float32) / 255
        return image[..., np.newaxis], sample.render_field_and_weights()

    for _batch_index in range(n_batches):
        pairs = [render_sample() for _sample_index in range(batch_size)]
        images = np.array([image for image, _target in pairs])
        targets = np.array([target for _image, target in pairs])
        yield images, targets

In [None]:
from tqdm import tqdm

# Time one full epoch of synthetic-data generation; the batches are discarded.
# A generator has no len(), so the batch count is passed to tqdm explicitly
# to get a real progress bar and ETA.
N_BATCHES = 10
for _batch in tqdm(epoch_gen(n_batches=N_BATCHES), total=N_BATCHES):
    pass

In [None]:
# Materialise one fixed epoch of generated batches as a reusable test set.
test_x, test_y = zip(*epoch_gen())
# from_tensor_slices makes each generated batch its own dataset element, matching
# the element structure of the training `ds`. from_tensors would instead create a
# single element with an extra leading (n_batches) axis that the model cannot consume.
# Both parts are cast to float32 to match the training pipeline.
test_set_ds = Dataset.from_tensor_slices(
    (
        tf.constant(np.stack(test_x).astype(np.float32)),
        tf.constant(np.stack(test_y).astype(np.float32)),
    )
)

In [None]:
# Training stream: each iteration over `ds` re-invokes epoch_gen(), so every epoch
# sees newly generated batches.
ds = Dataset.from_generator(
    epoch_gen,
    # `output_signature` replaces the deprecated `output_types` and also pins the rank:
    # X is (batch, H, W, 1) images; Y is (batch, H, W, channels) targets.
    output_signature=(
        tf.TensorSpec(shape=(None, None, None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None, None, None, None), dtype=tf.float32),
    ),
)

In [None]:
# Print the target-tensor shape of every element in the test dataset.
for _, batch_y in test_set_ds.as_numpy_iterator():
    print(batch_y.shape)

In [None]:
model.fit(ds.prefetch(1), epochs=1000, )

In [None]:
model.save("20220313-300epochs")