### Imports

In [1]:
### Standard imports
import os
import random
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf

from glob import glob
from os.path import join
from tensorflow import keras
from tensorflow.keras import layers
from tqdm import tqdm_notebook as tqdm

import warnings
# Silence every FutureWarning for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen when several libraries each bundle their own copy — confirm it
# is still needed on the target machine.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

### Model

In [98]:
def _load_model(hidden_activation="relu"):
    """Build and compile the network that predicts the "bold_signal" output.

    The model has three inputs: two volumes of shape (9, 9, 9, 2) named
    "prev" and "next", and a 130-element vector named "info". Both volumes
    are flattened, concatenated with the info vector, and passed through a
    stack of fully connected layers that ends in a single ReLU unit named
    "bold_signal".

    Args:
        hidden_activation: Activation used by every hidden Dense layer.
            Defaults to "relu". A Dense layer without an activation is a
            purely affine map, and a stack of affine maps collapses into one
            affine map, so the depth would add no modelling capacity. Pass
            None to get that purely linear stack.

    Returns:
        A keras.Model compiled with Adam (learning rate 1e-3) and a
        mean-squared-error loss on "bold_signal".
    """
    info_input = keras.Input(shape=(130,), name="info")
    prev_input = keras.Input(shape=(9, 9, 9, 2), name="prev")
    next_input = keras.Input(shape=(9, 9, 9, 2), name="next")

    # A single Flatten instance is applied to both volume branches.
    flatten = layers.Flatten()
    prev_features = flatten(prev_input)
    next_features = flatten(next_input)

    # The info vector is fed to the trunk as-is (no dedicated sub-network).
    x = layers.concatenate([prev_features, next_features, info_input])

    # Fully connected trunk, gradually narrowing towards the scalar output.
    for units in (994, 994, 800, 800, 512, 256, 256, 128, 64, 32):
        x = layers.Dense(units, activation=hidden_activation)(x)

    bold_signal = layers.Dense(1, activation="relu", name="bold_signal")(x)

    model = keras.Model(
        inputs=[prev_input, next_input, info_input],
        outputs=[bold_signal],
    )
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss={"bold_signal": "mean_squared_error"},
        loss_weights=[1.],
    )
    return model

# Build and compile the network once; the training cell calls fit() on this instance.
model = _load_model()

### Setup

In [99]:
# Root directory for everything this notebook reads and writes. It defaults
# to the original external volume; set the BOLD_RESULTS_DIR environment
# variable before starting the kernel to point the notebook somewhere else
# without editing this cell.
results_dir = os.environ.get("BOLD_RESULTS_DIR", "/Volumes/hd_4tb/results")

# CSV log that receives one row (target std, training loss) per fit call.
history_path = join(results_dir, "history.csv")
# NOTE(review): not referenced by the cells visible here — presumably the
# destination for model checkpoints; confirm.
checkpoint_path = join(results_dir, "checkpoints")
# Directory whose entries each hold pred.npy / prev.npy / next.npy / info.npy.
training_path = join(results_dir, "training")

batch_size  = 256  # mini-batch size passed to model.fit
num_epoches = 1    # epochs per model.fit call

### Training

In [100]:
# Every entry under training_path is a directory holding four arrays
# (pred / prev / next / info). Directories are pooled into groups; the model
# is fitted once per group and one (std, loss) row is appended to the CSV at
# history_path after each fit.
dirs_per_fit = 6           # number of directories pooled into one model.fit call
loss_alert_threshold = 10  # print the directory index when the loss exceeds this

batch_dirs = glob(join(training_path, "*"))
last_index = len(batch_dirs) - 1

bold_signal, prev_volume, next_volume, info = [], [], [], []
for i, batch_dir in enumerate(tqdm(batch_dirs)):
    bold_signal.extend(np.load(join(batch_dir, "pred.npy")))
    prev_volume.extend(np.load(join(batch_dir, "prev.npy")))
    next_volume.extend(np.load(join(batch_dir, "next.npy")))
    # NOTE(review): allow_pickle=True unpickles arbitrary objects on load;
    # only use it on files produced by a trusted pipeline.
    info.extend(np.load(join(batch_dir, "info.npy"), allow_pickle=True))

    # Fit when a full group has been collected, and also on the very last
    # directory so a trailing partial group is not silently discarded.
    if (i + 1) % dirs_per_fit != 0 and i != last_index:
        continue

    inputs = {
        "prev": np.array(prev_volume),
        "next": np.array(next_volume),
        "info": np.array(info),
    }
    history = model.fit(
        inputs, {'bold_signal': np.array(bold_signal)},
        epochs=num_epoches, batch_size=batch_size, verbose=False,
    )

    # np.round works for NumPy scalars and plain Python floats alike, so this
    # does not depend on which of the two Keras stores in history.history.
    _std = np.round(np.std(bold_signal), 3)
    _loss = np.round(history.history["loss"][0], 3)
    if _loss > loss_alert_threshold:
        print(i)

    # Append one row instead of re-reading and rewriting the whole file on
    # every fit; the header is written only when the file is new or empty.
    write_header = (
        not os.path.isfile(history_path)
        or os.path.getsize(history_path) == 0
    )
    pd.DataFrame({"std": [_std], "loss": [_loss]}).to_csv(
        history_path, mode="a", header=write_header, index=False
    )

    # Start collecting the next group from scratch.
    bold_signal, prev_volume, next_volume, info = [], [], [], []

HBox(children=(IntProgress(value=0, max=654), HTML(value='')))

5
