### Imports

In [2]:
### Standard imports
import os
import random
import datetime
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf

from glob import glob
from os.path import join
from random import shuffle
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import LeakyReLU
from tqdm import tqdm_notebook as tqdm

import warnings

# Hide FutureWarning noise emitted by the scientific stack.
warnings.filterwarnings("ignore", category=FutureWarning)
# Allow duplicate OpenMP runtimes to be loaded in one process
# (presumably working around the Intel OpenMP "already initialized" abort — confirm).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

### Model

In [3]:
def _load_model(info_dim=130, patch_shape=(9, 9, 9, 2), learning_rate=1e-5,
                leaky_alpha=0.05):
    """Build and compile the BOLD-signal regression model.

    Three inputs are combined:
      * ``prev`` / ``next``: volumes of shape ``patch_shape`` (batch axis
        excluded). Both pass through the *same* two Conv3D layers (shared
        weights) and are flattened.
      * ``info``: a flat vector of length ``info_dim``. It passes through two
        Dense layers of the same width with LeakyReLU activations.
    The three feature vectors are concatenated and fed to a ReLU MLP
    (994 -> 800 -> 32). The MLP ends in one sigmoid unit named ``bold_signal``.

    Args:
        info_dim: Length of the ``info`` input (and width of its Dense layers).
        patch_shape: Shape of each ``prev``/``next`` input, excluding batch.
        learning_rate: Adam learning rate.
        leaky_alpha: Negative slope of the info-branch LeakyReLU activations.

    Returns:
        A compiled ``keras.Model`` with inputs ``[prev, next, info]``, a single
        ``bold_signal`` output, and MSE loss.
    """
    info_input = keras.Input(shape=(info_dim,), name="info")
    prev_input = keras.Input(shape=patch_shape, name="prev")
    next_input = keras.Input(shape=patch_shape, name="next")

    # Each layer is created once and applied to both volumes, so the prev and
    # next branches share their convolution weights.
    volume_layers = [
        layers.Conv3D(2, (3, 3, 3), use_bias=False),
        layers.Conv3D(2, (2, 2, 2), use_bias=False),
        layers.Flatten(),
    ]
    prev_features, next_features = prev_input, next_input
    for layer in volume_layers:
        prev_features = layer(prev_features)
        next_features = layer(next_features)

    info_features = info_input
    for _ in range(2):
        info_features = layers.Dense(
            info_dim, activation=LeakyReLU(alpha=leaky_alpha))(info_features)

    x = layers.concatenate([prev_features, next_features, info_features])
    for units in (994, 800, 32):
        x = layers.Dense(units, activation="relu")(x)

    bold_signal = layers.Dense(1, activation="sigmoid", name="bold_signal")(x)

    model = keras.Model(inputs=[prev_input, next_input, info_input],
                        outputs=[bold_signal])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss={"bold_signal": "mse"},
        # Keyed like `loss` so the weight cannot be matched to the wrong output.
        loss_weights={"bold_signal": 1.0},
    )
    return model

# One TensorBoard run directory per execution, stamped with the local start time.
log_dir = "./logs/{:%Y%m%d-%H%M%S}".format(datetime.datetime.now())
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

model = _load_model()

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


### Setup

In [4]:
# Root directory for all run artefacts. Override it with the BOLD_RESULTS_DIR
# environment variable instead of editing the notebook.
results_dir = os.environ.get("BOLD_RESULTS_DIR", "/Volumes/hd_4tb/results")
history_path = join(results_dir, "history.csv")
checkpoint_path = join(results_dir, "checkpoints")
training_path = join(results_dir, "training")
batch_size  = 128
num_epoches = 3

# Start every run with an empty loss log; the training cell adds rows to it.
if os.path.isfile(history_path):
    os.remove(history_path)

### Training

In [None]:
files_per_chunk = 10  # batch directories accumulated before each model.fit call

# Sorted so the visiting order is reproducible (glob order depends on the
# filesystem). Re-enable shuffle(training) for a random order instead.
training = sorted(glob(join(training_path, "*", "*")))
# shuffle(training)


def _append_history(losses, target_std):
    """Append one row per epoch (target std, epoch loss) to the history CSV."""
    rows = pd.DataFrame({"std": [target_std] * len(losses), "loss": losses})
    # Append instead of re-reading and rewriting the whole file every chunk.
    # The header is written only when the file is new.
    rows.to_csv(history_path, mode="a", index=False,
                header=not os.path.isfile(history_path))


def _fit_chunk(bold_signal, prev_volume, next_volume, info, batch_path):
    """Fit `model` on one accumulated chunk, check for NaNs, and log losses.

    `batch_path` (the last directory of the chunk) only labels assertion
    failures. Returns the Keras History object from `model.fit`.
    """
    batch = {
        "prev": np.array(prev_volume),
        "next": np.array(next_volume),
        "info": np.array(info, dtype=np.float32),
    }
    targets = np.array(bold_signal)
    assert not np.isnan(targets).any(), batch_path
    for key, values in batch.items():
        assert not np.isnan(values).any(), (key, batch_path)

    history = model.fit(
        batch, {"bold_signal": targets},
        epochs=num_epoches, batch_size=batch_size, verbose=False,
        callbacks=[tensorboard_callback],
        shuffle=True,  # shuffles samples within this chunk each epoch
    )
    # float() handles both numpy scalars and Python floats from History.
    losses = [float(loss) for loss in history.history["loss"]]
    # Check every epoch; a NaN in a later epoch would otherwise be logged silently.
    assert not np.isnan(losses).any(), batch_path
    _append_history(losses, round(float(np.std(targets)), 3))
    return history


bold_signal, prev_volume, next_volume, info = [], [], [], []
for i, batch_path in enumerate(tqdm(training)):
    bold_signal.extend(np.load(join(batch_path, "pred.npy")))
    prev_volume.extend(np.load(join(batch_path, "prev.npy")))
    next_volume.extend(np.load(join(batch_path, "next.npy")))
    # NOTE(review): allow_pickle=True can run arbitrary code from a crafted
    # file; only use it on trusted, locally generated data.
    info.extend(np.load(join(batch_path, "norm_info.npy"), allow_pickle=True))

    is_last_file = i + 1 == len(training)
    # Also flush the final partial chunk so trailing files are not dropped.
    if (i + 1) % files_per_chunk == 0 or is_last_file:
        history = _fit_chunk(bold_signal, prev_volume, next_volume, info, batch_path)
        # Updated only after the NaN checks pass, so when an assertion fires this
        # holds the weights from the last chunk that trained cleanly.
        prev_weights = model.get_weights()
        bold_signal, prev_volume, next_volume, info = [], [], [], []

HBox(children=(IntProgress(value=0, max=13357), HTML(value='')))


KeyboardInterrupt






In [41]:
# Save alongside the other run artefacts (the directory holding history_path),
# so the location follows the Setup cell instead of a second hard-coded path.
model.save(join(os.path.dirname(history_path), "just_voxels.h5"))