how to load a pretrained model? #12

Closed
lijiaman opened this issue Oct 27, 2016 · 4 comments
@lijiaman

I want to load the pretrained AlexNet model bvlc_alexnet.npy. I know that np.load('bvlc_alexnet.npy') returns numpy arrays of the weights and biases (e.g., with net = np.load('bvlc_alexnet.npy'), net["conv1"][0] is the weights of conv1), but I have problems loading them into the Conv2d layer.

[screenshots of the code and the resulting error omitted]

Then, when running, it gives the error:
Variable cannot be callable.

Could someone give an example of loading the weights of a pretrained model? By the way, I tried to install TensorLayer from git, but when I ran the command "pip install . -e" it said -e needs an argument. How do I install it properly if I want to modify some of its code? Hope someone can help. Thanks very much.

@zsdonghao
Member

zsdonghao commented Oct 28, 2016

Hi, the flag goes before the path: use pip install -e . (the -e flag takes the path as its argument, so an editable install picks up your local changes).
As for loading a model, I think you can find the answer here: http://tensorlayer.readthedocs.io/en/latest/user/more.html#fqa

@lijiaman
Author

Thanks~

@wagamamaz
Collaborator

@lijiaman your W_init and b_init are wrong; this is not the way to use TensorLayer.

network.all_params is the list of all parameters in the network.

You can find examples of loading a list of numpy arrays into TensorLayer here (a short sketch follows the links):

  1. http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/files.html#load_npz
  2. main_restore_embedding_layer()
    https://github.com/zsdonghao/tensorlayer/blob/master/tutorial_generate_text.py
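
For the AlexNet case, a minimal sketch of that approach might look as follows. It assumes bvlc_alexnet.npy stores a pickled dict mapping layer names to [weights, biases] pairs; the placeholder shape and layer hyperparameters are illustrative, and only the first conv layer is shown:

import numpy as np
import tensorflow as tf
import tensorlayer as tl

# Assumption: the file holds a pickled dict {layer_name: [W, b]}.
# encoding/allow_pickle are needed on Python 3 / newer numpy.
pretrained = np.load('bvlc_alexnet.npy', encoding='latin1', allow_pickle=True).item()

# Build the graph first; do not pass the numpy arrays as W_init/b_init.
x = tf.placeholder(tf.float32, shape=[None, 227, 227, 3])
network = tl.layers.InputLayer(x, name='input')
network = tl.layers.Conv2d(network, n_filter=96, filter_size=(11, 11), strides=(4, 4),
                           act=tf.nn.relu, padding='VALID', name='conv1')
# ... remaining AlexNet layers ...

sess = tf.InteractiveSession()
tl.layers.initialize_global_variables(sess)

# assign_params copies the arrays into network.all_params in order
# (here [conv1/W, conv1/b]); extend the list as layers are added.
tl.files.assign_params(sess, [pretrained['conv1'][0], pretrained['conv1'][1]], network)
network.print_params()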

@zsdonghao
Member

This is a good example provided by @wagamamaz:

#! /usr/bin/python
# -*- coding: utf8 -*-

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import set_keep
import time

is_test_only = False # if True, restore and test without training

X_train, y_train, X_val, y_val, X_test, y_test = \
                tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))

X_train = np.asarray(X_train, dtype=np.float32)[0:10000]  # <-- small training set for fast debugging
y_train = np.asarray(y_train, dtype=np.int64)[0:10000]
X_val = np.asarray(X_val, dtype=np.float32)
y_val = np.asarray(y_val, dtype=np.int64)
X_test = np.asarray(X_test, dtype=np.float32)
y_test = np.asarray(y_test, dtype=np.int64)

print('X_train.shape', X_train.shape)
print('y_train.shape', y_train.shape)
print('X_val.shape', X_val.shape)
print('y_val.shape', y_val.shape)
print('X_test.shape', X_test.shape)
print('y_test.shape', y_test.shape)
print('X %s   y %s' % (X_test.dtype, y_test.dtype))

sess = tf.InteractiveSession()

batch_size = 128
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
y_ = tf.placeholder(tf.int64, shape=[None,])

def inference(x, is_train, reuse=None):
    # gamma_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope("CNN", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        network = tl.layers.InputLayer(x, name='input_layer')

        network = tl.layers.Conv2d(network, n_filter=32, filter_size=(5, 5), strides=(1, 1),
                act=None, b_init=None, padding='SAME', name='cnn_layer1')
        network = tl.layers.BatchNormLayer(network, act=tf.identity,#tf.nn.relu,
                # gamma_init=gamma_init,
                is_train=is_train, name='batch1')
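        # keep a handle on the BatchNorm output so its statistics can be inspected with sess.run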
        check = network.outputs
        network.outputs = tf.nn.relu(network.outputs)
        network = tl.layers.MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                padding='SAME', name='pool_layer1')

        network = tl.layers.Conv2d(network, n_filter=64, filter_size=(5, 5), strides=(1, 1),
                act=None, b_init=None, padding='SAME', name='cnn_layer2')
        network = tl.layers.BatchNormLayer(network, act=tf.nn.relu,
                # gamma_init=gamma_init,
                is_train=is_train, name='batch2')
        network = tl.layers.MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                padding='SAME', name='pool_layer2')

        ## end of conv
        network = tl.layers.FlattenLayer(network, name='flatten_layer')
        # if is_train:
        #     network = tl.layers.DropoutLayer(network, keep=0.5, is_fix=True, name='drop1')
        network = tl.layers.DenseLayer(network, n_units=256,
                                        act = tf.nn.relu, name='relu1')
        # if is_train:
        #     network = tl.layers.DropoutLayer(network, keep=0.5, is_fix=True, name='drop2')
        network = tl.layers.DenseLayer(network, n_units=10,
                                        act = tf.identity, name='output_layer')
    return network, check


# train phase
network, check = inference(x, is_train=True, reuse=False)
y = network.outputs
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=y_))
correct_prediction = tf.equal(tf.argmax(y, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# test phase
network_test, check_t = inference(x, is_train=False, reuse=True)
y_t = network_test.outputs
cost_t = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y_t, labels=y_))
correct_prediction_t = tf.equal(tf.argmax(y_t, 1), y_)
acc_t = tf.reduce_mean(tf.cast(correct_prediction_t, tf.float32))

# train
n_epoch = 5
learning_rate = 0.0001
print_freq = 1

train_params = network.all_params
train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999,
    epsilon=1e-08, use_locking=False).minimize(cost, var_list=train_params)

tl.layers.initialize_global_variables(sess)

if is_test_only:
    load_params = tl.files.load_npz(name='_model_test.npz')
    tl.files.assign_params(sess, load_params, network)

# network.print_params(True)
network.print_layers()

# tl.layers.print_all_variables(train_only=True)

for i, p in enumerate(tf.all_variables()):
    print("  Before {:3}: {:15} (mean: {:<18}, median: {:<18}, std: {:<18})   {}".format(i, str(p.eval().shape), p.eval().mean(), np.median(p.eval()), p.eval().std(), p.name))

print('   learning_rate: %f' % learning_rate)
print('   batch_size: %d' % batch_size)

if not is_test_only:
    for epoch in range(n_epoch):
        start_time = time.time()
        for X_train_a, y_train_a in tl.iterate.minibatches(
                                    X_train, y_train, batch_size, shuffle=True):
            _, c = sess.run([train_op, check], feed_dict={x: X_train_a, y_: y_train_a})
            # print('bn out train:', np.mean(c), np.std(c))

        if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
            print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
            train_loss, train_acc, n_batch = 0, 0, 0
            for X_train_a, y_train_a in tl.iterate.minibatches(
                                    X_train, y_train, batch_size, shuffle=True):
                err, ac = sess.run([cost_t, acc_t], feed_dict={x: X_train_a, y_: y_train_a})
                train_loss += err; train_acc += ac; n_batch += 1
            print("   train loss: %f" % (train_loss/ n_batch))
            print("   train acc: %f" % (train_acc/ n_batch))
            val_loss, val_acc, n_batch = 0, 0, 0
            for X_val_a, y_val_a in tl.iterate.minibatches(
                                        X_val, y_val, batch_size, shuffle=True):
                err, ac = sess.run([cost_t, acc_t], feed_dict={x: X_val_a, y_: y_val_a})
                val_loss += err; val_acc += ac; n_batch += 1
            print("   val loss: %f" % (val_loss/ n_batch))
            print("   val acc: %f" % (val_acc/ n_batch))

print('Evaluation')
test_loss, test_acc, n_batch = 0, 0, 0
for X_test_a, y_test_a in tl.iterate.minibatches(
                            X_test, y_test, batch_size=1, shuffle=True):
    err, ac, c = sess.run([cost_t, acc_t, check_t], feed_dict={x: X_test_a, y_: y_test_a})
    # print('bn out test:', np.mean(c), np.std(c))
    test_loss += err; test_acc += ac; n_batch += 1
print("   test loss: %f" % (test_loss/n_batch))
print("   test acc: %f" % (test_acc/n_batch))

# network.print_params(True)

if not is_test_only:
    tl.files.save_npz(network.all_params, name='_model_test.npz', sess=sess)

for i, p in enumerate(tf.all_variables()):
    print("  After {:3}: {:15} (mean: {:<18}, median: {:<18}, std: {:<18})   {}".format(i, str(p.eval().shape), p.eval().mean(), np.median(p.eval()), p.eval().std(), p.name))

zsdonghao pushed a commit that referenced this issue May 4, 2019
BiRNN refactored, tested and doc updated, other minor updates to rnn, fixed dropout bug, make format