In [4]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np
from preprocessing import get_preprocessed_dataset

In [11]:
# Data location and sample cap, lifted out of the call so they are easy to tune.
DATA_DIR = "/data"  # NOTE(review): the traceback below reports 'data/fold_0_data.txt' (relative) — confirm how get_preprocessed_dataset resolves this path
N_MAX = 200         # forwarded as n_max; presumably caps the number of samples loaded — confirm in preprocessing

(
    X_train,
    gender_train,
    age_train,
    X_test,
    gender_test,
    age_test,
) = get_preprocessed_dataset(DATA_DIR, n_max=N_MAX)

FileNotFoundError: [Errno 2] No such file or directory: 'data/fold_0_data.txt'

In [None]:
# Build the baseline single-task CNN (trained below on the gender labels).
# Idea: extend the function so that it's easy to change the architecture.
def CNN_classic(conv_filters=(16, 32, 64), dense_units=128, input_shape=None):
    """Build the single-output baseline CNN.

    Layers are, in order:
      * Rescaling(1./255),
      * for each entry of ``conv_filters``: Conv2D(filters, 3, padding='same',
        relu) followed by MaxPooling2D(),
      * Flatten, Dense(dense_units, relu), Dense(1, sigmoid).
    The defaults reproduce the original three-stage 16/32/64 network.

    Args:
        conv_filters: number of filters of each conv stage, in order.
        dense_units: width of the hidden dense layer.
        input_shape: optional (height, width, channels). When given, an Input
            layer is prepended so the model is built immediately (e.g. for
            ``model.summary()``). When None (default) the model is built on
            its first call, exactly as before.

    Returns:
        An uncompiled ``keras.Sequential`` named "CNN_classic".
    """
    model_layers = []
    if input_shape is not None:
        model_layers.append(keras.Input(shape=input_shape))
    model_layers.append(layers.Rescaling(1./255))
    for filters in conv_filters:
        model_layers.append(layers.Conv2D(filters, 3, padding='same', activation='relu'))
        model_layers.append(layers.MaxPooling2D())
    model_layers += [
        layers.Flatten(),
        layers.Dense(dense_units, activation='relu'),
        layers.Dense(1, activation='sigmoid'),
    ]
    return keras.Sequential(model_layers, name="CNN_classic")

# Build the multitask network: a shared convolutional trunk with separate gender and age heads.
def CNN_multitask(img_size, n_age_classes=8):
    """Build a two-headed CNN with a shared convolutional trunk.

    Outputs (in this order, with these layer names):
      * 'gender': Dense(1, sigmoid),
      * 'age':    Dense(n_age_classes, softmax).
    The output layer names are the keys to use for losses, metrics and labels
    when compiling and fitting.

    Args:
        img_size: height and width of the square, 3-channel input images.
        n_age_classes: number of units in the 'age' softmax head (default 8,
            as in the original). NOTE(review): SparseCategoricalCrossentropy
            (used by the training cell) expects integer labels in
            [0, n_age_classes) — confirm against preprocessing.

    Returns:
        An uncompiled ``tf.keras.Model`` named "CNN_multitask".
    """
    inputs = tf.keras.layers.Input(shape=(img_size, img_size, 3), name='input')

    # Shared trunk. Rescale pixels the same way CNN_classic does, since both
    # models are fed the same X_train. NOTE(review): this assumes X_train holds
    # raw 0-255 values; if preprocessing already normalizes, drop Rescaling
    # from both models.
    trunk = tf.keras.layers.Rescaling(1./255)(inputs)
    for filters in (16, 32, 64):
        trunk = tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(trunk)
        # Pool after every conv stage, like CNN_classic. The original skipped
        # pooling after the 64-filter conv, so Flatten emitted
        # (img_size/4)^2 * 64 features (~25.7M weights in the next Dense at 224px).
        trunk = tf.keras.layers.MaxPooling2D()(trunk)
    trunk = tf.keras.layers.Flatten()(trunk)
    trunk = tf.keras.layers.Dense(128, activation='relu')(trunk)

    # Gender head: single sigmoid unit.
    gender_out = tf.keras.layers.Dense(256, activation='relu')(trunk)
    gender_out = tf.keras.layers.Dense(128, activation='relu')(gender_out)
    gender_out = tf.keras.layers.Dense(1, activation='sigmoid', name='gender')(gender_out)

    # Age head: softmax over n_age_classes classes.
    age_out = tf.keras.layers.Dense(256, activation='relu')(trunk)
    age_out = tf.keras.layers.Dense(128, activation='relu')(age_out)
    age_out = tf.keras.layers.Dense(n_age_classes, activation='softmax', name='age')(age_out)

    # No space in the model name: TF graph name scopes only accept
    # [A-Za-z0-9_.\-/>] characters, so "Multitask CNN" can fail once the model
    # is traced. The new name also matches the CNN_classic naming convention.
    # Note that this changes the key under which fit_models stores its history.
    return tf.keras.Model(inputs=inputs, outputs=[gender_out, age_out], name="CNN_multitask")

# Bundle one train/test split into a dict so that all models can be trained by the same function.
def create_data_dict(X_train, y_train, X_test, y_test):
    """Bundle one train/test split into a dict for ``fit_models``.

    Keys, in insertion order: "X_train", "y_train", "X_test", "y_test".
    """
    return dict(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

# fit all the models to their respective datasets
# ** More arguments to adjust fitting procedure **
def fit_models(model_data, no_epochs, verbose = 0, **fit_kwargs):
    """Fit every model on its own data split and collect the results.

    Args:
        model_data: dict mapping a model object to a data dict with keys
            "X_train", "y_train", "X_test", "y_test" (as built by
            ``create_data_dict``).
        no_epochs: forwarded to ``model.fit`` as ``epochs``.
        verbose: forwarded to ``model.fit`` as ``verbose``.
        **fit_kwargs: any extra keyword arguments (e.g. ``batch_size``,
            ``callbacks``), forwarded unchanged to ``model.fit``.

    Returns:
        dict mapping ``model.name`` to whatever ``model.fit`` returned. Models
        that share a name overwrite each other's entry.
    """
    histories = {}
    for model, data in model_data.items():
        print(f"Training on Model: {model.name}")
        # Look the splits up by key. Unpacking data.values() silently
        # mis-assigns arrays if the dict was built in a different key order.
        histories[model.name] = model.fit(
            data["X_train"],
            data["y_train"],
            epochs=no_epochs,
            verbose=verbose,
            validation_data=(data["X_test"], data["y_test"]),
            **fit_kwargs,
        )
        print()
    return histories


def compile_model(model, loss, optimizer = 'adam', metrics = ('accuracy',)):
    """Compile ``model`` in place with the given loss, optimizer and metrics.

    Args:
        model: object with a Keras-style ``compile`` method.
        loss: loss object/name, or a dict keyed by output-layer name for
            multi-output models. Forwarded unchanged.
        optimizer: optimizer name or instance. Forwarded unchanged.
        metrics: a list/tuple of metrics, copied into a fresh list on every
            call so no list object is shared between compiled models (the
            original mutable list default was). A dict, a string or None is
            forwarded unchanged.
    """
    if isinstance(metrics, (list, tuple)):
        metrics = list(metrics)
    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics)


In [None]:
IMG_SIZE = 224  # NOTE(review): must equal the height/width of X_train images — confirm against preprocessing
N_EPOCHS = 5

model_classic = CNN_classic()
compile_model(model_classic, keras.losses.BinaryCrossentropy())

model_multitask = CNN_multitask(IMG_SIZE)
# Losses and metrics are keyed by the output layer names ('gender', 'age').
multitask_loss = {'gender': keras.losses.BinaryCrossentropy(),
                  'age': keras.losses.SparseCategoricalCrossentropy()}
# Per-output metrics: a bare ['accuracy'] list is rejected for a 2-output
# model by Keras 3. A dict works in both Keras 2 and 3.
multitask_metrics = {'gender': ['accuracy'], 'age': ['accuracy']}
compile_model(model_multitask, multitask_loss, metrics=multitask_metrics)

model_data = {
    model_classic: create_data_dict(X_train, gender_train, X_test, gender_test),
    # Labels keyed by output name, so they are matched to heads by name
    # rather than by position.
    model_multitask: create_data_dict(
        X_train,
        {'gender': gender_train, 'age': age_train},
        X_test,
        {'gender': gender_test, 'age': age_test}),
}

# Keep the returned training histories; the original discarded them.
histories = fit_models(model_data, N_EPOCHS)
              

NameError: name 'CNN_classic' is not defined