In [1]:
# Reproducibility setup: seed the RNGs before the heavier libraries are imported.
# NOTE(review): PYTHONHASHSEED is read by the interpreter at startup, so setting it
# here does not change hash randomization of this already-running kernel; it only
# affects processes started afterwards with this environment. Set it before
# launching Jupyter if deterministic hashing is required.
import os
os.environ['PYTHONHASHSEED'] = '0'
# Python's built-in `random` module.
import random as rn
rn.seed(12345)
# NumPy's legacy global RNG.
from numpy.random import seed
seed(42)
# TensorFlow graph-level seed via the TF1 compatibility API.
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

# NOTE(review): nibabel, nipy, pandas, datetime, csv, glob, pyplot, Reshape, Dropout,
# BatchNormalization and ModelCheckpoint are not used in the cells visible here
# (ModelCheckpoint only appears in commented-out code) — confirm before pruning.
import nibabel as nib
import nipy as ni
import numpy as np
import tensorflow as tf
import pandas as pd
import tensorflow_probability as tfp

# Short alias for TFP distributions (used for the von Mises-Fisher output).
tfd = tfp.distributions

import datetime
import shutil
import yaml
import csv
import glob

from hashlib import md5
from matplotlib import pyplot as plt
from tensorflow.keras.layers import (Input, Reshape, Dropout, BatchNormalization, Lambda, Dense)
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras import backend as K
from multiprocessing import cpu_count
from GPUtil import getFirstAvailable

# NOTE(review): this silences *all* Python warnings for the whole kernel, including
# deprecation warnings from TF/TFP — consider narrowing the filter by category/module.
import warnings
warnings.filterwarnings('ignore')

In [2]:
class conditional_samples(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves ``(inputs, outputs)`` batches from an ``.npz`` file.

    The archive ``<config["sample_dir"]>/samples.npz`` must contain two arrays,
    ``"inputs"`` and ``"outputs"``, with the same number of rows. At most
    ``config["max_n_samples"]`` leading rows are kept; ``np.inf`` keeps all of them.

    Config keys read:
        batch_size:    number of rows per batch.
        istraining:    if true, the trailing partial batch is dropped; otherwise it
                       is served as a smaller final batch.
        sample_dir:    directory containing ``samples.npz``.
        max_n_samples: upper bound on the number of rows used (int or float).

    Raises:
        AssertionError: if the two arrays differ in length or no rows remain.
    """

    def __init__(self, config):
        self.batch_size = config["batch_size"]
        self.istraining = config["istraining"]

        sample_path = os.path.join(config["sample_dir"], "samples.npz")
        # np.load on an .npz returns an NpzFile that holds the file open until it is
        # closed; the context manager releases the handle once both arrays are read.
        with np.load(sample_path) as samples:
            self.inputs = samples["inputs"]
            self.outputs = samples["outputs"]

        assert len(self.inputs) == len(self.outputs), \
            "inputs and outputs must have the same number of rows"

        # max_n_samples may be a float (np.inf by default). Slice bounds must be
        # ints, so clamp to the available rows and convert explicitly.
        n_keep = len(self.inputs)
        if config["max_n_samples"] < n_keep:
            n_keep = int(config["max_n_samples"])
        self.inputs = self.inputs[:n_keep]
        self.outputs = self.outputs[:n_keep]

        self.n_samples = len(self.inputs)

        assert self.n_samples > 0, "no samples left after applying max_n_samples"

    def __len__(self):
        """Number of batches per epoch.

        In training mode the remainder is dropped, so this is 0 when
        ``n_samples < batch_size``. Otherwise a final partial batch is included.
        """
        if self.istraining:
            return self.n_samples // self.batch_size  # drop remainder
        # Integer ceiling division: counts the trailing partial batch.
        return (self.n_samples + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x_batch, y_batch)``.

        ``np.vstack`` stacks the rows of each slice. 2-D arrays keep their shape
        (a copy is returned) and 1-D arrays become column vectors.
        NOTE(review): for arrays with more than 2 dims, vstack merges the first two
        axes; confirm the sample arrays are at most 2-D.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        x_batch = np.vstack(self.inputs[start:stop])
        y_batch = np.vstack(self.outputs[start:stop])

        return x_batch, y_batch

In [3]:
def _hash_training_config(sample_path, config, chunk_size=1 << 20):
    """Return the md5 hex digest that identifies a training run.

    The digest covers the raw bytes of ``sample_path``, followed by
    ``str(config[k]).encode()`` for k in ``max_n_samples``, ``batch_size`` and
    ``epochs``, in that order. The file is streamed in ``chunk_size`` pieces,
    which gives the same digest as hashing its whole contents at once.
    NOTE(review): the network architecture is not part of the digest, so
    changing it does not change the output directory.
    """
    hasher = md5()
    with open(sample_path, "rb") as sample_file:
        for chunk in iter(lambda: sample_file.read(chunk_size), b""):
            hasher.update(chunk)
    for key in ("max_n_samples", "batch_size", "epochs"):
        hasher.update(str(config[key]).encode())
    return hasher.hexdigest()


def _build_vmf_model(input_shape, name):
    """Build an MLP whose output layer is a von Mises-Fisher distribution.

    Two 1024-unit ReLU layers feed two heads:
      * ``mu``: 3 units, L2-normalised to a unit vector (the mean direction).
      * ``kappa``: a single ReLU unit squeezed to one scalar per sample (the
        concentration).
    A ``DistributionLambda`` combines them into ``tfd.VonMisesFisher``. When the
    output is converted to a tensor, the distribution's mean is used.
    """
    inputs = Input(shape=input_shape, name="inputs")

    x = Dense(1024, activation="relu")(inputs)
    x = Dense(1024, activation="relu")(x)

    mu = Dense(512, activation="relu")(x)
    mu = Dense(3, activation="linear")(mu)
    mu = Lambda(lambda t: K.l2_normalize(t, axis=-1), name="mu")(mu)

    # NOTE(review): a ReLU output can be exactly 0 and then passes no gradient;
    # a strictly positive activation (e.g. softplus) may train more reliably.
    kappa = Dense(512, activation="relu")(x)
    kappa = Dense(1, activation="relu")(kappa)
    kappa = Lambda(lambda t: K.squeeze(t, 1), name="kappa")(kappa)

    outputs = tfp.layers.DistributionLambda(
        make_distribution_fn=lambda params: tfd.VonMisesFisher(mean_direction=params[0],
                                                               concentration=params[1]),
        convert_to_tensor_fn=tfd.Distribution.mean
    )([mu, kappa])

    return tf.keras.Model(inputs, outputs, name=name)


def train_model(config):
    """Train the von Mises-Fisher model described by ``config`` and save it.

    Runs are cached by content. The output directory is
    ``../models/<model_name>/<md5>``, relative to the working directory, where the
    digest covers the sample file and a few config values. If that directory
    already exists, the function prints a message and returns ``None`` without
    training.

    On success, ``config`` is written to ``config.yml`` and the model to
    ``model.h5`` inside that directory, and the trained model is returned.
    If training raises an ``Exception``, the directory is removed and the
    exception is re-raised unchanged. A non-``Exception`` interruption (e.g.
    ``KeyboardInterrupt``) still saves the config and model before propagating.

    Args:
        config: dict with at least ``sample_dir``, ``model_name``,
            ``max_n_samples``, ``batch_size`` and ``epochs``; ``istraining`` is
            also read by ``conditional_samples``.

    Returns:
        The trained ``tf.keras.Model``, or ``None`` if this configuration was
        trained before.
    """
    sample_path = os.path.join(config["sample_dir"], "samples.npz")
    save_dir = os.path.join("..", "models", config["model_name"],
                            _hash_training_config(sample_path, config))

    if os.path.exists(save_dir):
        print("This model config has been trained already:\n{}".format(save_dir))
        return

    # Only the per-sample input shape is needed here; close the npz handle afterwards.
    with np.load(sample_path) as samples:
        input_shape = samples["inputs"].shape[1:]

    model = _build_vmf_model(input_shape, config["model_name"])
    model.summary()

    def negative_log_likelihood(observed_y, predicted_distribution):
        # predicted_distribution is the DistributionLambda output; the loss is the
        # batch-mean negative log-probability of the observed directions.
        return -K.mean(predicted_distribution.log_prob(observed_y))

    train_seq = conditional_samples(config)

    no_exception = True
    try:
        os.makedirs(save_dir, exist_ok=True)

        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss=negative_log_likelihood
        )
        # Validation, checkpointing and early stopping are currently disabled.
        # NOTE(review): max_queue_size counts batches, not samples, so scaling it
        # by batch_size may be unintended.
        model.fit_generator(
            train_seq,
            epochs=config["epochs"],
            callbacks=[TensorBoard(log_dir=save_dir, write_graph=False)],
            max_queue_size=2 * config["batch_size"],
            use_multiprocessing=True,
            workers=cpu_count()
        )
    except Exception:
        # Record the failure before cleaning up. If cleanup itself runs into
        # trouble, the finally block still won't try to save into a removed
        # directory, and the original error propagates.
        no_exception = False
        shutil.rmtree(save_dir, ignore_errors=True)
        raise
    finally:
        if no_exception:
            config_path = os.path.join(save_dir, "config.yml")
            print("Saving {}".format(config_path))
            with open(config_path, "w") as file:
                yaml.dump(config, file, default_flow_style=False)

            model_path = os.path.join(save_dir, "model.h5")
            print("Saving {}".format(model_path))
            model.save(model_path)

    return model

In [4]:
# Pin this kernel to a single GPU: getFirstAvailable (GPUtil), ordered by load,
# presumably returns ids of GPUs below the given load / memory thresholds, and the
# first id is exported. NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if set
# before TensorFlow initialises its GPU devices; TF is already imported above, so
# confirm no GPU op has run before this cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(
    getFirstAvailable(
        order="load",
        maxMemory=10**-1,
        maxLoad=10**-6,
    )[0]
)

In [5]:
# Training configuration consumed by train_model and conditional_samples.
config = {
    "istraining": True,        # conditional_samples drops the trailing partial batch
    "model_name": "entrack_conditional",
    "sample_dir": "../subjects/992774/samples/fa7c02604b92de5f32cd3b61dbc2f8b7",
    "max_n_samples": np.inf,   # np.inf keeps every sample in samples.npz
    "batch_size": 128,
    "epochs": 1,
}

In [7]:
# Train (or skip if cached). train_model returns None when this configuration's
# output directory already exists, so `model` is only usable after a fresh run.
model = train_model(config=config)

W1014 13:01:35.505799 139703451047680 deprecation.py:323] From /local/home/vwegmayr/miniconda2/envs/thesis/lib/python3.6/site-packages/tensorflow_probability/python/distributions/von_mises_fisher.py:312: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Model: "entrack_conditional"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inputs (InputLayer)             [(None, 18)]         0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 1024)         19456       inputs[0][0]                     
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1024)         1049600     dense[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 512)          524800      dense_1[0][0]                    
________________________________________________________________________________