# Using Keras for Model Fusion
* Chris Arnold (Cardiff University)
* DS3 Data Science Summer School
* August 25th, 2023


In [None]:
# Housekeeping
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

## Understanding Functional and Sequential Models in Keras

In [None]:
# Quick recap: a Sequential Neural Regression Model
# The input shape is declared with an explicit keras.Input object as the first
# entry of the model. Passing `input_shape=` to the first Dense layer is
# discouraged in current Keras and triggers a warning there.
model_seq = keras.Sequential()
model_seq.add(keras.Input(shape=(10,)))  # 10 features per sample
model_seq.add(layers.Dense(64, activation="relu"))
model_seq.add(layers.Dense(32, activation="relu"))
model_seq.add(layers.Dense(1, activation="linear"))  # single regression output

# Take a look at the model
model_seq.summary()

In [None]:
# A Functional Neural Regression Model
# Input: one vector of 10 features per sample
inputs = keras.Input(shape=(10,))

# Model: every layer is called on the output tensor of the previous one
hidden_1 = layers.Dense(64, activation="relu")(inputs)
hidden_2 = layers.Dense(32, activation="relu")(hidden_1)
prediction = layers.Dense(1, activation="linear")(hidden_2)

# Create the model by naming its input and output tensors
model_fun = keras.Model(inputs=inputs, outputs=prediction)

model_fun.summary()

In [None]:
# For comparison: print the summary of the Sequential model again, which
# stacks the same Dense(64) -> Dense(32) -> Dense(1) layers on a 10-feature input
model_seq.summary()

## A Fusion Model in Keras

In [None]:
# Generate two inputs, one per branch
inputA = keras.Input(shape=(32,))
inputB = keras.Input(shape=(128,))

# First branch operates on the first input (32 -> 8 -> 4 units)
a_hidden = layers.Dense(8, activation="relu")(inputA)
a_features = layers.Dense(4, activation="relu")(a_hidden)
branch_a = keras.Model(inputs=inputA, outputs=a_features)

# Second branch operates on the second input (128 -> 64 -> 32 -> 4 units)
b_hidden = layers.Dense(64, activation="relu")(inputB)
b_hidden = layers.Dense(32, activation="relu")(b_hidden)
b_features = layers.Dense(4, activation="relu")(b_hidden)
branch_b = keras.Model(inputs=inputB, outputs=b_features)

# Combine the outputs of the two branches with a concatenation
merged = layers.concatenate([branch_a.output, branch_b.output])

# Fully connected layer and regression prediction on the combined outputs
head = layers.Dense(2, activation="relu")(merged)
prediction = layers.Dense(1, activation="linear")(head)

# The model accepts the inputs of the two branches and
# then returns a single prediction
model_fusion = keras.Model(
    inputs=[branch_a.input, branch_b.input],
    outputs=prediction,
)

In [None]:
# Take a look at the fusion model: two inputs, two branches, one output
model_fusion.summary()

Sources:
* https://keras.io/guides/functional_api/
* https://pyimagesearch.com/2019/02/04/keras-multiple-inputs-and-mixed-data/