
Microbatch size error: only works when num_microbatches == 1 #17

Closed
MADONOKOUKI opened this issue Feb 15, 2019 · 10 comments


MADONOKOUKI commented Feb 15, 2019

I implemented a different neural network model and minimized its loss with dp_optimizer.DPGradientDescentGaussianOptimizer.

Training succeeds when num_microbatches is 1, but when num_microbatches is greater than 1 I get the following error:

Traceback (most recent call last):
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 686, in _call_cpp_shape_fn_impl
    input_tensors_as_shapes, status)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 516, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension size must be evenly divisible by 2 but is 1 for 'Reshape' (op: 'Reshape') with input shapes: [], [2] and with input tensors computed as partial shapes: input[1] = [2,?].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "autoencoder_dp.py", line 70, in <module>
    population_size=60000).minimize(cost)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 399, in minimize
    grad_loss=grad_loss)
  File "/home/madono/madono/test2/dpgan/privacy/optimizers/dp_optimizer.py", line 68, in compute_gradients
    microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 5782, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3292, in create_op
    compute_device=compute_device)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3332, in _create_op_helper
    set_shapes_for_outputs(op)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2496, in set_shapes_for_outputs
    return _set_shapes_for_outputs(op)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2469, in _set_shapes_for_outputs
    shapes = shape_func(op)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2399, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 627, in call_cpp_shape_fn
    require_shape_fn)
  File "/home/madono/.pyenv/versions/anaconda3-2018.12/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 691, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Dimension size must be evenly divisible by 2 but is 1 for 'Reshape' (op: 'Reshape') with input shapes: [], [2] and with input tensors computed as partial shapes: input[1] = [2,?].
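
The op that fails is the reshape of the loss into microbatches inside compute_gradients. As far as I can tell it can be reproduced in isolation; this is just a rough sketch with illustrative values:

import tensorflow as tf

num_microbatches = 2

# A scalar loss (shape []) cannot be split into microbatches and
# triggers the "Dimension size must be evenly divisible" error:
scalar_loss = tf.constant(1.0)
# tf.reshape(scalar_loss, [num_microbatches, -1])

# A per-example loss (shape [batch_size]) reshapes fine:
vector_loss = tf.ones([256])
microbatch_losses = tf.reshape(vector_loss, [num_microbatches, -1])  # shape [2, 128]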

My code looks like this:

optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=2,
    learning_rate=0.0002,
    population_size=60000).minimize(cost)
# optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)

# Initializing the variables
init = tf.global_variables_initializer()

# Explore trainable variables (weight_bias)
var = [v for v in tf.trainable_variables() if 'mimiciii/fc/autoencoder' in v.name] # (784, 128), (128,), (128, 784), (784,)
var_grad = tf.gradients(cost, var) # gradient of cost w.r.t. trainable variables, len(var_grad): 8, type(var_grad): list
norm_gradient_variables = []

# Launch the graph
with tf.Session() as sess:
    writer = tf.summary.FileWriter("./graph/my_graph", sess.graph)
    sess.run(init)
    total_batch = int(mnist.train.num_examples/batch_size)
    # Training cycle
    for epoch in range(training_epochs):
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
            var_grad_val = sess.run(var_grad, feed_dict={X: batch_xs})

            # var_grad_val = [var_grad_val[0], var_grad_val[2]]  # no bias, change for different network
            if not isinstance(var_grad_val, list):  # a single weight matrix comes back un-listed
                var_grad_val = [var_grad_val]
            norm_gradient_variables.append(norm_w(var_grad_val))  # compute the norm of all trainable variables
        # Display logs per epoch step
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch+1),
                  "cost=", "{:.9f}".format(c))
@schien1729
Collaborator

Thanks for bringing this up, MADONOKOUKI. Can I ask what the value of batch_size is?

@MADONOKOUKI
Author

@schien1729
Thank you for replying.
The batch size is 256.

@schien1729
Collaborator

It seems that the code thinks that the loss you're passing (the variable cost) has length 1, and therefore can't split it into two microbatches. cost should be a vector of length 256 (your batch size). Is it possible that you've turned it into a scalar, perhaps by using a reduce function?
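
For an autoencoder-style loss, one way to keep one value per example (a rough sketch, assuming y_true is the input batch and y_pred the reconstruction, both of shape [batch_size, 784]) is to average the squared error over the feature dimension only:

# One loss value per example, shape [batch_size], instead of a scalar.
vector_cost = tf.reduce_mean(tf.square(y_true - y_pred), axis=1)

# The DP optimizer can then split this into microbatches internally;
# batch_size must be divisible by num_microbatches for the reshape to work.
optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=2,
    learning_rate=0.0002,
    population_size=60000).minimize(vector_cost)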

@MADONOKOUKI
Author

I used mean squared error to minimize the reconstruction error between the input and the output, so I wrote:

ae_net = Autoencoder(inputDim, l2scale, compressDims, aeActivation, decompressDims,
                           dataType)  # autoencoder network
clipnorm = 5.0
standard_deviation = 0.0001


# tf Graph input (only pictures)
X = tf.placeholder("float", [None, inputDim])

# Construct model
loss, latent, output  = ae_net(X)
print(output.shape)
# Prediction
y_pred = output
# Targets (Labels) are the input data.
y_true = X

cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
print(y_pred)
print(y_true)
cost = tf.losses.mean_squared_error(labels=y_true,
    predictions=y_pred,
)
# Calculate loss as a vector (to support microbatches in DP-SGD).
# vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
#     labels=y_true, logits=y_pred)
# cost = tf.reduce_mean(vector_loss)

# optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)  # in medgan
# Use DP version of GradientDescentOptimizer. For illustration purposes,
# we do that here by calling optimizer_from_args() explicitly, though DP
# versions of standard optimizers are available in dp_optimizer.
optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=256,
    learning_rate=0.15,
    population_size=60000).minimize(cost)

Is this cost function wrong? If so, which TensorFlow loss function should I use here?

@schien1729
Collaborator

The issue might be that by default, tf.losses.mean_squared_error aggregates the losses over all the examples it's given. Can you try this instead?

cost = tf.losses.mean_squared_error(
    labels=y_true,
    predictions=y_pred,
    reduction=tf.losses.Reduction.NONE)

@MADONOKOUKI
Author

Thanks!!!!

cost = tf.losses.mean_squared_error(
    labels=y_true,
    predictions=y_pred,
    reduction="none")

I can now train my model with this loss. I couldn't import Reduction directly, so I passed the string "none" instead.

@amanoel

amanoel commented Mar 8, 2019

Hi, I'm having this same error on the MNIST Keras example shipped with the package

ValueError: Dimension size must be evenly divisible by 250 but is 1 for 'training/TFOptimizer/Reshape' (op: 'Reshape') with input shapes: [], [2] and with input tensors computed as partial shapes: input[1] = [250,?].

but only in TensorFlow 1.13; in 1.12 it works fine. Any ideas? Thanks!

@npapernot
Collaborator

Have you made the one-liner modification documented at the top of the MNIST Keras example?

@amanoel

amanoel commented Mar 9, 2019

Sorry, completely missed it :) I just copied the relevant code somewhere else, so I ended up not looking at the docstring.

So I guess this is just #21, never mind my comment!

@Viserion-nlper

I'm also running into the same problem. How can I solve it?
