In [4]:
import datetime
from glob import glob
import math
import matplotlib.pyplot as plt
import nibabel as nb
from nilearn import image
import numpy as np
import os
import pandas as pd
import random
import tensorflow as tf
import timeit
import warnings

In [None]:
from source_code.data_io import Dataset_Pipeline, _get_data

devices = ['/gpu:0', '/gpu:1']

class_type=True #true if QC, false if site

if class_type:
    import source_code.models.basic_qc_cnn as model
    train_cache_prefix="/home/smantra/finalproject/cache_train_qc/"
    eval_cache_prefix="/home/smantra/finalproject/cache_eval_qc/"
    d = devices[0]
else:
    import source_code.models.basic_site_cnn as model
    train_cache_prefix="/home/smantra/finalproject/cache_train_sites/"
    eval_cache_prefix="/home/smantra/finalproject/cache_eval_sites/"
    d = devices[1]

tf.logging.set_verbosity(tf.logging.INFO)

if __name__ == '__main__':
    log_dir = "logs"
    current_run_subdir = os.path.join(
        "run_" + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
#    model_dir = os.path.join(log_dir, model.name, "106x128x110")#current_run_subdir)
    model_dir = os.path.join(log_dir, model.name, '592018')

    run_config = tf.estimator.RunConfig(model_dir=model_dir)

    params = tf.contrib.training.HParams(
        target_shape=(106, 128, 110),
        model_dir=model_dir
    )

    ds = Dataset_Pipeline(target_shape=params.target_shape,
                          n_epochs=10,
                          train_src_folder="/home/smantra/finalproject/data/",
                          train_cache_prefix="/home/smantra/finalproject/cache_train/",
                          eval_src_folder="/home/smantra/finalproject/eval/",
                          eval_cache_prefix="/home/smantra/finalproject/cache_eval/",
                          batch_size=4
                         )

    # Workaround for cache iterator concurency issues. Iterate over the whole
    # training dataset without counterbalancing to make sure everything is
    # preprocessed and cached
    if not os.path.exists(ds.train_cache_prefix + ".index"):
        print("Preprocessing the training set")
        with tf.Session() as sess:
            train_dataset = _get_data(batch_size=ds.batch_size,
                                      src_folder=ds.train_src_folder,
                                      n_epochs=1,
                                      cache_prefix=ds.train_cache_prefix,
                                      shuffle=False,
                                      target_shape=params.target_shape,
                                     )

            train_dataset = train_dataset.make_one_shot_iterator()
            while True:
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        features, (qc_labels, site_labels) = sess.run(train_dataset.get_next())

                except tf.errors.OutOfRangeError:
                    break
        print("Finished preprocessing the training set")

    train_spec = tf.estimator.TrainSpec(input_fn=ds.train_input_fn,
                                        max_steps=20000,
                                       )
    eval_spec = tf.estimator.EvalSpec(input_fn=ds.eval_input_fn,
                                      steps=None,
                                      start_delay_secs=0,
                                      throttle_secs=600)

    estimator = tf.estimator.Estimator(model_fn=model.model_fn,
                                       params=params,
                                       config=run_config)
    
    config = tf.ConfigProto() 
    config.gpu_options.allow_growth = True 
    with tf.Session(config=config) as sess:
        with tf.device(d):
            sess.run(tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec))

INFO:tensorflow:Using config: {'_model_dir': 'logs/basic_qc_cnn/592018', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f173439e208>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after 600 secs (eval_spec.throttle_secs) or training is finished.
<TensorSliceDataset shapes: ((),), types: (tf.string,)>
INFO:tensorflow:Calling model_fn.
Tensor("Shape:0", shape=(5,), dtype=int32)
Tensor("Shape_1:0", shap