DNN.evaluate(x,y) → "ValueError: Cannot use the given session to evaluate tensor" #966

Open
bwllc opened this issue Nov 26, 2017 · 5 comments



bwllc commented Nov 26, 2017

Hello again!

While I'm waiting to hear back from folks regarding the compatibility of TFLearn with scikit-learn's cross-validation tools (issue #965), I've decided to try building a cross-validation system myself. In doing so, I've encountered a more basic problem.

Here's a somewhat abridged version of my code:

window = "whole"
encoding = "one_hot"

db = Database("/home/bw/Documents/compact")
traindb, testdb = db.train_test_split()
train = traindb.values(window, encoding)
test = testdb.values(window, encoding)

topology = 60, 40, 20
model = MyDNNSubclass(window, topology)
model.fit(*train, validation_set=test, n_epoch=40, batch_size=3)
print("Run completed.  Minimum val_loss = {:.4f}.". format(model.min_val_loss))

inp, tgt = test
print("inp", inp.shape)
print("tgt", tgt.shape)
pred = model.predict(inp)
print("pred", pred.shape)
score = model.evaluate(inp, tgt)
print("score", score)

NOTE 1: MyDNNSubclass is, as its name suggests, a subclass of TFLearn's DNN class. The subclass does NOT override evaluate().

NOTE 2: I have a Callback attached to the fit() method in MyDNNSubclass that performs early stopping. So even though I specify 40 epochs, that's an upper limit; the fitting process usually does not run that long.

Here's a somewhat abridged version of a typical output:

/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tflearn/initializations.py:119: UniformUnitScaling.__init__ (from tensorflow.python.ops.init_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.initializers.variance_scaling instead with distribution=uniform to get equivalent behavior.
2017-11-25 20:45:10.883990: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2017-11-25 20:45:10.977634: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Found device 0 with properties: 
name: GeForce GTX 760 major: 3 minor: 0 memoryClockRate(GHz): 1.176
pciBusID: 0000:04:00.0
totalMemory: 1.94GiB freeMemory: 1.64GiB
2017-11-25 20:45:10.977669: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 760, pci bus id: 0000:04:00.0, compute capability: 3.0)
2017-11-25 20:45:12.079002: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 760, pci bus id: 0000:04:00.0, compute capability: 3.0)

Run id: IWJV3S
Log directory: /tmp/tflearn_logs/

Training samples: 266
Validation samples: 89

Training Step: 89  | total loss: 1.12626 | time: 2.585s
| Adam | epoch: 001 | loss: 1.12626 | val_loss: 1.16531 -- iter: 266/266

Training Step: 178  | total loss: 1.12943 | time: 1.543s
| Adam | epoch: 002 | loss: 1.12943 | val_loss: 1.19077 -- iter: 266/266

(snip)

Training Step: 1513  | total loss: 1.00457 | time: 1.544s
| Adam | epoch: 017 | loss: 1.00457 | val_loss: 1.06024 -- iter: 266/266

Training Step: 1602  | total loss: 1.00449 | time: 1.540s
| Adam | epoch: 018 | loss: 1.00449 | val_loss: 1.05571 -- iter: 266/266

Run completed.  Minimum val_loss = 1.0557.
inp (89, 398, 20)
tgt (89, 398, 3, 2)
pred (89, 398, 3, 2)

Traceback (most recent call last):
  File "exp54.py", line 66, in <module>
    score = model.evaluate(inp, tgt)
  File "/usr/local/lib/python3.6/dist-packages/tflearn/models/dnn.py", line 370, in evaluate
    return self.predictor.evaluate(feed_dict, ops, batch_size)
  File "/usr/local/lib/python3.6/dist-packages/tflearn/helpers/evaluator.py", line 95, in evaluate
    tflearn.is_training(False, self.session)
  File "/usr/local/lib/python3.6/dist-packages/tflearn/config.py", line 97, in is_training
    tf.get_collection('is_training_ops')[1].eval(session=session)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 570, in eval
    return _eval_using_default_session(self, feed_dict, self.graph, session)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 4452, in _eval_using_default_session
    raise ValueError("Cannot use the given session to evaluate tensor: "
ValueError: Cannot use the given session to evaluate tensor: the tensor's graph is different from the session's graph.

=====
(program exited with code: 1)
Press return to continue

You can see that my model initializes and fits. I've examined the output, and the model is learning. I can pass input to model.predict(), and it returns a tensor of the correct shape (it matches the target).

But to build cross-validation myself, I need to compute the loss on held-out folds. For now, I'm just trying my test set. I need model.evaluate() to run, and that's where I'm getting the failure.
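
For reference, the loop I have in mind looks something like this (just a sketch; X and Y stand for my full input and target arrays, and it assumes evaluate() actually works):

from sklearn.model_selection import KFold
import numpy as np

# Sketch of the manual cross-validation loop I have in mind.  X and Y
# are the full input and target arrays; a fresh model is trained on
# each training fold and scored on the held-out fold.
kf = KFold(n_splits=5, shuffle=True)
scores = []
for train_idx, test_idx in kf.split(X):
    model = MyDNNSubclass(window, topology)
    model.fit(X[train_idx], Y[train_idx], n_epoch=40, batch_size=3)
    # evaluate() returns a list of metric scores; take the first.
    scores.append(model.evaluate(X[test_idx], Y[test_idx])[0])
print("Mean CV score: {:.4f}".format(np.mean(scores)))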

Do I have to do something with a tensorflow.Session()? I thought Session object management was handled by the TFLearn code. I am looking at the TFLearn source and trying to figure it out.
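
If this is a graph/session mismatch, maybe I need something like the following at the call site (a sketch only; it assumes the DNN object exposes its session as model.session, and I haven't confirmed it helps):

# Evaluate inside the graph that owns the model's session, in case the
# default graph has drifted away from it (e.g. after a
# tf.reset_default_graph() call).
with model.session.graph.as_default():
    score = model.evaluate(inp, tgt)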

Yes, I have some warnings. Perhaps they are relevant.

Thanks for any suggestions!

@EllyMandliel

Could you provide the wrapper class code (or some of it)?

Are you using multithreading?


bwllc commented Dec 11, 2017

Thanks for your reply, EllyMandliel!

As far as I know, I am not using multithreading. I only have one GPU, and my TensorFlow build is configured to use the GPU.

You may regret asking me for my tfl.DNN subclass code, because there's a lot in there which concerns the specifics of my project (protein folding). I hope it isn't too distracting. I am attaching a copy below.

import numpy as np  # needed by MonitorCallback below
import tensorflow as tf
import tflearn as tfl

DIM = {"one_hot" : 20, "physicochemical" : 4}

##========

def l2_angle_distance(predict, actual):
    """ 
    Customized loss function for backbone angle calculations in protein
    folding.  Cosine distances are calculated, and then squared, to
    give the loss function quadratic (L2) behavior.  NaN values are masked.
    """
    with tf.name_scope("L2AngleDistance"):
        # Calculate a scaling factor: the number of FINITE values in 
        # actual, must be typecast to a tf.float32.
        count = actual[...,0,0]
        scale = tf.to_float(tf.count_nonzero(tf.is_finite(count)))
        # Mask actual: change all NaN in actual to whatever was predicted.
        # Stops NaN from propagating into the loss calculation.  This will
        # instead calculate cosine distances of zero for any rows where 
        # actual was NaN.
        actual = tf.where(tf.is_nan(actual), predict, actual)
        # Supply the -1 argument for axis (that TFLearn can't pass in)
        losses = tf.losses.cosine_distance(predict, actual, -1, 
                 reduction=tf.losses.Reduction.NONE)
        # Square the losses, then sum, to get L2 scalar loss.
        # Finally, divide the loss result by the scaling factor.
        return tf.reduce_sum(losses * losses) / scale

##========

class MyDNNSubclass(tfl.DNN):

    def __init__(self, window, topology, encoding="one_hot"):
        self.window = window
        self.topology = topology
        self._keep_prob = 2/3  # Make this a hyperparameter?
        if encoding not in ("one_hot", "physicochemical"):
            raise ValueError("Encoding must be 'one_hot' or 'physicochemical'.")
        self.encoding = encoding
        if window == "whole_protein":
            net = self._convolutional_net(topology, DIM[encoding])
        else:
            if (window < 3) or (not window % 2):
                raise ValueError("Window must be 'whole_protein', or an odd integer, at least 3.")
            net = self._dense_net(window, topology, DIM[encoding])
        net = tfl.regression(net, optimizer="adam", learning_rate=0.001, loss=l2_angle_distance)
        super().__init__(net) # Checkpoints not working?  Verbosity not working?
    
    def _convolutional_net(self, topology, width):
        # Input placeholder
        net = tfl.input_data(shape=[None, None, width])
        # Need to keep two dimensions of input shape, see below
        sh = tf.shape(net)
        # Add all trainable layers to the graph, except the last.  For
        # now, every layer convolves with a window of 3.
        for size in topology:
            net = tfl.conv_1d(net, size, 3)
            net = tfl.dropout(net, keep_prob=self._keep_prob)
        # The last trainable layer always has a fixed size of six outputs.
        net = tfl.conv_1d(net, 6, 1)
        # Output layers, reshape and normalize, not trainable.
        net = tfl.reshape(net, [sh[0], sh[1], 3, 2])
        net = tf.nn.l2_normalize(net, dim=3)
        return net
    
    def _dense_net(self, window, topology, width):
        # Input placeholder
        net = tfl.input_data(shape=[None, window*width])
        # Add all trainable layers to the graph, except the last.
        for size in topology:
            net = tfl.fully_connected(net, size)
            net = tfl.dropout(net, keep_prob=self._keep_prob)
        # The last trainable layer always has a fixed size of six outputs.
        net = tfl.fully_connected(net, 6)
        # Output layers, reshape and normalize, not trainable.
        net = tfl.reshape(net, [-1,3,2])
        net = tf.nn.l2_normalize(net, dim=2)
        return net
    
    def fit(self, *args, **kwargs):
        self.restarts = 0
        tf.reset_default_graph()  # https://github.com/tflearn/tflearn/issues/408
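        # NOTE (my guess only): reset_default_graph() installs a brand-new,
        # empty default graph, while the session created in __init__ still
        # belongs to the graph the network was built on.  That mismatch may
        # be related to the evaluate() failure described in this issue.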
        cb = MonitorCallback()
        while True:
            try:
                super().fit(*args, **kwargs, callbacks=cb)
            except NonFiniteResult:
                # (See Note 1)
                self.restarts += 1
                print("> NON FINITE result, restart #{} <\n".format(self.restarts))
                continue
            except StopIteration:
                pass
            break
        self.epoch = cb.epoch
        self.min_val_loss = cb.min_val_loss
    
##========

class NonFiniteResult(Exception):
    pass

##========

class MonitorCallback(tfl.callbacks.Callback):
    
    # https://github.com/tflearn/tflearn/blob/master/tflearn/helpers/trainer.py
    # See lines 1059 - 1079.
    def __init__(self, long=8, short=4):
        self.val_loss = np.zeros((0,))
        self.lengths = -long, -short
        self.epoch = 0
    
    def on_epoch_end(self, training_state):
        self.epoch += 1
        if not np.isfinite(training_state.global_loss):
            raise NonFiniteResult
        self.val_loss = np.append(self.val_loss, training_state.val_loss)
        long_mean, short_mean = [self.val_loss[x:].mean() for x in self.lengths]
        if long_mean / short_mean < 0.999:
            self.finish(training_state)
            raise StopIteration
    
    def on_train_end(self, training_state):
        self.finish(training_state)
    
    def finish(self, training_state):
        self.min_val_loss = self.val_loss.min()

@EllyMandliel

Can you try overriding evaluate(), delegating to the base class's evaluate(), but first calling:
tf.reset_default_graph()


bwllc commented Dec 13, 2017

Thanks again, Elly. I tried your suggestion and added the following to MyDNNSubclass:

    def evaluate(self, *args, **kwargs):
        # See https://github.com/tflearn/tflearn/issues/966
        print("\n** Overriding evaluate(). **\n")
        tf.reset_default_graph()
        return super().evaluate(*args, **kwargs)

Other than printing my override notice and the part of the traceback that corresponds to my override code, everything is the same. I'm still ending with...

ValueError: Cannot use the given session to evaluate tensor: the tensor's graph is different from the session's graph.

So, nothing has changed.
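
One more variant I may try: instead of resetting the default graph, make the model's own graph the default for the duration of the call (again just a sketch, assuming DNN exposes the underlying session as self.session):

    def evaluate(self, *args, **kwargs):
        # Run the base-class evaluate() with the session's own graph as the
        # default, so that is_training() resolves its collection in the same
        # graph the session belongs to.
        with self.session.graph.as_default():
            return super().evaluate(*args, **kwargs)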


bwllc commented Jan 16, 2018

Hello,

So it has been a month since EllyMandliel attempted to assist me, and we reached a dead end. I believe that TFLearn's DNN.evaluate() is mishandling TensorFlow Session objects (see the traceback above).

In the intervening time, I almost managed to port my entire project over to raw TensorFlow. When I hit a wall there, I tried Keras. I have Keras working, but my models aren't training to the same scores that I achieved with TFLearn. I'm not yet sure why, but I think it has something to do with the way Keras processes batches of data. TFLearn was performing better, and training faster too.

If anyone can help me understand my TFLearn bug, I would deeply appreciate it. Thanks!
