In [1]:
# Reproducibility setup: fix the seeds of every random source this notebook
# touches before any model code runs.
import os
# NOTE(review): Python reads PYTHONHASHSEED when the interpreter starts, so
# assigning it here presumably only affects child processes spawned later
# (e.g. multiprocessing workers), not the hashing of this running kernel —
# confirm, and export it before launching Jupyter if that matters.
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)  # Python's built-in `random` module
from numpy.random import seed
seed(42)  # NumPy's global random state
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)  # TensorFlow seed (compat.v1 API)

import nibabel as nib
import nipy as ni
import numpy as np
import tensorflow as tf
import pandas as pd
import tensorflow_probability as tfp

tfd = tfp.distributions

import datetime
import shutil
import yaml
import csv
import glob

from hashlib import md5
from matplotlib import pyplot as plt
from tensorflow.keras.layers import (Input, Reshape, Dropout, BatchNormalization, Lambda, Dense, GRU)
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras import backend as K
from multiprocessing import cpu_count
from GPUtil import getFirstAvailable

import warnings
warnings.filterwarnings('ignore')

In [32]:
class conditional_samples(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves mini-batches from a ``samples.npz`` archive.

    The archive ``<config["sample_dir"]>/samples.npz`` must contain two arrays,
    ``inputs`` and ``outputs``, of equal length along axis 0. Batches are
    contiguous slices along that axis, in file order (no shuffling).
    """

    def __init__(self, config):
        """Load the sample arrays and apply the optional sample cap.

        Args:
            config: Mapping providing
                ``batch_size``    - number of samples per batch,
                ``istraining``    - if truthy, the trailing partial batch is dropped,
                ``time_step``     - stored on the instance; not used by this class,
                ``sample_dir``    - directory containing ``samples.npz``,
                ``max_n_samples`` - upper bound on the number of samples kept
                                    (``np.inf`` keeps everything).

        Raises:
            AssertionError: if ``inputs`` and ``outputs`` differ in length, or
                if no samples remain after applying the cap.
        """
        self.batch_size = config["batch_size"]
        self.istraining = config["istraining"]
        self.time_step = config["time_step"]

        # np.load returns an NpzFile that keeps the archive open and reads each
        # member on access. Materialise both arrays inside the context manager
        # so the file handle is released as soon as they are in memory.
        with np.load(os.path.join(config["sample_dir"], "samples.npz")) as samples:
            inputs = samples["inputs"]
            outputs = samples["outputs"]

        assert len(inputs) == len(outputs), "inputs and outputs differ in length"

        # The cap may be a float such as np.inf. Only convert it to int when it
        # actually truncates, so an infinite cap never reaches int() and a
        # finite float cap yields a valid slice bound.
        n_available = len(inputs)
        max_n_samples = config["max_n_samples"]
        n_keep = n_available if max_n_samples >= n_available else int(max_n_samples)

        self.inputs = inputs[:n_keep]
        self.outputs = outputs[:n_keep]

        self.n_samples = len(self.inputs)

        assert self.n_samples > 0, "no samples available"

    def __len__(self):
        """Return the number of batches per epoch as a plain ``int``.

        In training mode the trailing partial batch is dropped; otherwise it is
        counted as an additional (smaller) batch.
        """
        full_batches, remainder = divmod(self.n_samples, self.batch_size)
        if self.istraining or remainder == 0:
            return full_batches
        return full_batches + 1

    def __getitem__(self, idx):
        """Return the ``(inputs, outputs)`` slices that make up batch ``idx``."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.inputs[start:stop, ...], self.outputs[start:stop, ...]

In [33]:
def _training_digest(sample_path, config):
    """Return the MD5 hex digest that names the output directory of a run.

    The digest covers the raw bytes of the samples file followed by the string
    forms of ``max_n_samples``, ``batch_size`` and ``epochs`` (in that order).
    """
    hasher = md5()
    with open(sample_path, "rb") as sample_file:
        # Feed the file in 1 MiB chunks: same digest as hashing the whole
        # content at once, without holding the entire archive in memory.
        for chunk in iter(lambda: sample_file.read(1 << 20), b""):
            hasher.update(chunk)
    for key in ("max_n_samples", "batch_size", "epochs"):
        hasher.update(str(config[key]).encode())
    return hasher.hexdigest()


def train_model(config):
    """Build a stacked-GRU regression model, train it and save it to disk.

    The run is stored under ``../models/<model_name>/<digest>``. If that
    directory already exists nothing is trained.

    Args:
        config: Mapping providing ``sample_dir``, ``model_name``,
            ``hidden_sizes`` (units of each GRU layer, at least one entry),
            ``batch_size``, ``epochs``, ``max_n_samples``, plus the keys
            required by ``conditional_samples``.

    Returns:
        The trained ``tf.keras.Model``, or ``None`` when the output directory
        already exists.

    Raises:
        Exception: any ``Exception`` raised while compiling or fitting is
            re-raised after the output directory has been removed.
    """
    sample_path = os.path.join(config["sample_dir"], "samples.npz")

    # NOTE(review): the digest does not include hidden_sizes, so changing only
    # the architecture maps to the same directory and is reported as "trained
    # already". Left unchanged so existing run directories stay discoverable.
    save_dir = os.path.join(
        "..", "models", config["model_name"], _training_digest(sample_path, config)
    )

    if os.path.exists(save_dir):
        print("This model config has been trained already:\n{}".format(save_dir))
        return None

    # Define Model Function

    # Read both shapes from a single archive handle and close it right away.
    with np.load(sample_path) as samples:
        input_shape = samples["inputs"].shape[1:]
        # Dense acts on the last axis of the GRU sequence output, so its width
        # has to match the last axis of the targets (not axis 1, which is the
        # time axis for 3-D targets).
        n_outputs = samples["outputs"].shape[-1]

    inputs = Input(shape=input_shape, batch_size=config["batch_size"], name="inputs")

    def model_fn(inputs):
        """Stack one sequence-returning GRU per hidden size, then a linear Dense."""
        x = inputs
        for hidden_size in config['hidden_sizes']:
            x = GRU(hidden_size, return_sequences=True)(x)
        return Dense(n_outputs, activation='linear')(x)

    model = tf.keras.Model(inputs, model_fn(inputs), name=config["model_name"])
    model.summary()

    # Run Training

    train_seq = conditional_samples(config)
    training_failed = False
    try:
        os.makedirs(save_dir, exist_ok=True)

        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss='mean_squared_error'
        )
        # NOTE(review): fit_generator is deprecated in newer TensorFlow releases
        # in favour of fit — confirm against the TensorFlow version in use.
        model.fit_generator(
            train_seq,
            epochs=config["epochs"],
            callbacks=[
                TensorBoard(log_dir=save_dir, write_graph=False),
            ],
            max_queue_size=2*config["batch_size"],
            use_multiprocessing=True,
            workers=cpu_count()
        )
    except Exception:
        # Set the flag first and ignore clean-up errors, so a failing rmtree
        # can neither mask the original exception nor let the finally clause
        # save a broken run.
        training_failed = True
        shutil.rmtree(save_dir, ignore_errors=True)
        raise
    finally:
        # Also runs when training is interrupted by something that is not an
        # Exception subclass (e.g. KeyboardInterrupt): the model state reached
        # so far is saved before the interrupt propagates.
        if not training_failed:
            config_path = os.path.join(save_dir, "config" + ".yml")
            print("Saving {}".format(config_path))
            with open(config_path, "w") as file:
                yaml.dump(config, file, default_flow_style=False)

            model_path = os.path.join(save_dir, "model.h5")
            print("Saving {}".format(model_path))
            model.save(model_path)

    return model

In [36]:
# Training configuration shared by conditional_samples and train_model.
config = {
    "istraining": True,           # training mode: trailing partial batch is dropped
    "model_name": "rnn_pd",       # model name and sub-directory under ../models
    "hidden_sizes": [500, 500],   # units of each stacked GRU layer
    "time_step": 1,
    "sample_dir": "../subjects/992774/samples/e09d97e5447d8a380d51369e4e4f692f",
    "max_n_samples": np.inf,      # no cap: use every sample in samples.npz
    "batch_size": 128,
    "epochs": 10,
}

In [37]:
# Train with the configuration defined above. train_model returns the trained
# model, or None when a run directory for this setup already exists on disk.
model = train_model(config)

Model: "rnn_pd"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
inputs (InputLayer)          [(128, 3, 408)]           0         
_________________________________________________________________
gru_14 (GRU)                 (128, 3, 500)             1365000   
_________________________________________________________________
gru_15 (GRU)                 (128, 3, 500)             1503000   
_________________________________________________________________
dense_6 (Dense)              (128, 3, 3)               1503      
Total params: 2,869,503
Trainable params: 2,869,503
Non-trainable params: 0
_________________________________________________________________
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Saving ../models/rnn_pd/cc1276f6bdc33f83d085f09884ced9df/config.yml
Saving ../models/rnn_pd/cc1276f6bdc33f83d085f09884ced9df/model.h5


In [30]:
# Sanity check: shape of the first batch of targets served by the sequence.
# The sequence is built here because train_model keeps its own instance local;
# without this line the cell only works with a variable left over from an
# earlier kernel session and fails on Restart & Run All.
train_seq = conditional_samples(config)
train_seq.outputs[:train_seq.batch_size, ...].shape

(128, 3, 3)