This repository has been archived by the owner on Oct 19, 2019. It is now read-only.

Retrain resnet model on new data. #12

Open
SilviaLauraPintea opened this issue Aug 8, 2016 · 5 comments

Comments

@SilviaLauraPintea

Hi,

I am trying to add a new fully connected (fc) layer on top of the ResNet avg_pool layer, with a different number of outputs to suit my problem.
I do not want to retrain only the new fc layer but also the previous layers, so I need the gradients of those layers as well. Unfortunately, this does not seem to work.
As a test, I created a dummy network and saved it (without the gradients, so similar to the provided ResNet .meta and .ckpt files), then loaded it and added a new fc layer. This worked without problems.
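The goal described above is full fine-tuning: the loss gradient flows through the new fc layer back into the pretrained layers, so both are updated, rather than training only the new head. As an illustration only (NumPy, not TensorFlow, and not the ResNet code from this repository), the minimal sketch below uses a hypothetical one-layer ReLU "backbone" as a stand-in for the layers up to `avg_pool` and a new linear head with 3 outputs; the names `sgd_step`, `backbone_W`, `head_W` and the `train_backbone` flag are made up for this example.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def sgd_step(x, labels, backbone_W, head_W, lr=0.05, train_backbone=True):
    """One SGD step of softmax cross-entropy for a toy ReLU backbone + new linear head.

    backbone_W is a hypothetical stand-in for the pretrained layers,
    head_W for the newly added fc layer. Returns the updated backbone,
    the updated head, and the loss computed before the update.
    """
    n = x.shape[0]
    pre = x @ backbone_W               # pretrained layer pre-activation
    feats = np.maximum(pre, 0.0)       # ReLU features (plays the role of avg_pool)
    probs = softmax(feats @ head_W)    # new fc layer followed by softmax
    loss = -np.log(probs[np.arange(n), labels]).mean()

    # Backpropagate the mean cross-entropy loss.
    d_logits = probs.copy()
    d_logits[np.arange(n), labels] -= 1.0
    d_logits /= n
    d_head = feats.T @ d_logits        # gradient for the new fc layer
    d_feats = d_logits @ head_W.T      # gradient flowing back into the backbone
    d_backbone = x.T @ (d_feats * (pre > 0))

    new_head = head_W - lr * d_head
    new_backbone = backbone_W - lr * d_backbone if train_backbone else backbone_W.copy()
    return new_backbone, new_head, loss


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))                # toy batch of 8 inputs
labels = rng.integers(0, 3, size=8)        # 3 classes for the new head
backbone = rng.normal(size=(5, 4))         # "pretrained" weights
head = rng.normal(scale=0.1, size=(4, 3))  # freshly initialized new fc layer

# Full fine-tuning: both the backbone and the new head are updated.
ft_backbone, ft_head, first_loss = sgd_step(x, labels, backbone, head)
# Head-only training: the backbone is left unchanged.
fr_backbone, fr_head, _ = sgd_step(x, labels, backbone, head, train_backbone=False)
```

With `train_backbone=False` the backbone weights come back unchanged, while the head update is the same in both modes for this first step, because it depends only on the forward pass through the original weights.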

Here is a snapshot of my retraining code:

```python
import tensorflow as tf

# FLAGS, dataAsTensors, fc_num_outs and loss are defined elsewhere in retrain.py.

# Start the session.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

# Get the data batches.
trainimages, trainlabels = dataAsTensors(is_training=True, batch_size=FLAGS.batch_size)

# Build everything in the default graph.
graph = tf.get_default_graph()
with graph.as_default():

    # Import the graph structure only; the weights are restored below.
    dataSaver = tf.train.import_meta_graph('ResNet-L50.meta')

    for op in graph.get_operations():
        print(op.name)

    # Get the 'images' input and the 'avg_pool' output tensors.
    images = graph.get_tensor_by_name('images:0')
    avgpool = graph.get_tensor_by_name('avg_pool:0')

    # Define a new fc layer on top of the avg_pool layer.
    logits, _ = fc_num_outs(avgpool, FLAGS.num_classes, FLAGS.avgpool_size)

    # Define a placeholder for the labels and the loss on top of the new fc layer.
    labelsVar = tf.placeholder(tf.int64, shape=(FLAGS.batch_size,), name='labelsVar')
    loss_ = loss(logits, labelsVar)

    # Define the gradients and get the training operation.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    ops = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = ops.minimize(loss_, global_step=global_step)

    # The saver returned by import_meta_graph only covers the original ResNet
    # variables, so create a new saver that also checkpoints the new fc layer.
    fullSaver = tf.train.Saver()

    with sess.as_default():

        # Initialize all variables.
        sess.run(tf.initialize_all_variables())

        # Restore the ResNet checkpoint after initialization.
        dataSaver.restore(sess, 'ResNet-L50.ckpt')

        # Start the input queue runners once the variables are initialized.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)

        for i in range(FLAGS.max_steps):
            # Fetch images and labels in one call so that both come from the same batch.
            npImages, npLabels = sess.run([trainimages, trainlabels])

            # Run one step of gradient descent.
            sess.run(train_op, {images: npImages, labelsVar: npLabels})
            print('Done running grad step %d' % i)

            if i % 100 == 0:  # Save a checkpoint.
                fullSaver.save(sess, 'resnet_retrained' + str(i) + '.ckpt')

    coord.request_stop()
    coord.join(threads)
    sess.close()
```
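A note on the batch-fetching step in the loop: when `dataAsTensors` is built on a TensorFlow input queue (for example with `tf.train.batch`; this is an assumption, since its definition is not shown), the image and label tensors are outputs of the same dequeue operation. Calling `trainimages.eval()` and then `trainlabels.eval()` runs that dequeue twice, so the labels fed with a batch of images can belong to the next batch, whereas a single `sess.run([trainimages, trainlabels])` call dequeues once for both. The pure-Python sketch below uses `queue.Queue` as a hypothetical stand-in for the TensorFlow queue; the helper names are illustrative only.

```python
import queue


def make_batch_queue(num_batches, batch_size):
    """Fill a queue with (images, labels) batches; each label is its batch index."""
    q = queue.Queue()
    for b in range(num_batches):
        images = ['img_%d_%d' % (b, k) for k in range(batch_size)]
        labels = [b] * batch_size
        q.put((images, labels))
    return q


def separate_dequeues(q):
    # Mimics trainimages.eval() followed by trainlabels.eval():
    # each evaluation runs the dequeue again and takes a new batch.
    images, _ = q.get()
    _, labels = q.get()
    return images, labels


def joint_dequeue(q):
    # Mimics sess.run([trainimages, trainlabels]): one dequeue serves both outputs.
    images, labels = q.get()
    return images, labels


mismatched = separate_dequeues(make_batch_queue(4, 2))  # images from batch 0, labels from batch 1
matched = joint_dequeue(make_batch_queue(4, 2))         # images and labels both from batch 0
```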

I am not sure why I get this error for the ResNet model:

```text
  File "retrain.py", line 278, in main
    retrain()
  File "retrain.py", line 244, in retrain
    train_op = ops.minimize(loss_, global_step=global_step)
  File "tensorflow/python/training/optimizer.py", line 193, in minimize
    grad_loss=grad_loss)
  File "tensorflow/python/training/optimizer.py", line 250, in compute_gradients
    colocate_gradients_with_ops=colocate_gradients_with_ops)
  File "tensorflow/python/ops/gradients.py", line 467, in gradients
    out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
  File "tensorflow/python/ops/control_flow_ops.py", line 1047, in ZerosLikeOutsideLoop
    pred = op_ctxt.pred
AttributeError: 'NoneType' object has no attribute 'pred'
```

while for my own toy model the same code seems to work.

Thanks a lot.
Cheers,
Silvia

@KalraA

KalraA commented Feb 10, 2017

I'm getting the same error!
@SilviaLauraPintea
Have you found a solution?

@leiup

leiup commented Mar 5, 2017

How can this problem be solved? I am also getting the same error!
@SilviaLauraPintea @KalraA @ry
Thank you very much!

@nikste

nikste commented Mar 24, 2017

Has anybody solved this?

@nu1ptr

nu1ptr commented Jun 11, 2017

I'm having this issue as well. Has anybody been able to figure it out yet?

@rener1199

I am also getting the same error! How can it be solved?
@SilviaLauraPintea @KalraA @ry @nu1ptr
Thanks


6 participants