
Memory Leak Running simple feed_dict graph #9091

Closed
JonathanRaiman opened this issue Apr 9, 2017 · 17 comments
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower

Comments

@JonathanRaiman
Contributor

JonathanRaiman commented Apr 9, 2017

In a series of simple TensorFlow programs I observe memory leaks (unbounded growth of CPU memory).
In the original program, on a computer with 64 GB of RAM, the leak is about 640 megabytes per hour (1% of total memory).

[Plots of the computer's memory over time: one on a long time scale, one on a short time scale.]

Problem description

The original program was more advanced and included RNNs, saving/loading, etc., but I narrowed it down to a simple for loop with no gradient descent where memory grows over time without bound.
Tested on Fedora 25 and Mac OS X 10.11.5. The issue occurs when running on a single GPU (Titan X Pascal) or on CPU. Varying the sizes of the variables in the graph only changes the rate of growth; it does not prevent the effect. The issue occurs on TensorFlow 0.12 and on the current TensorFlow 1.0.1. No custom code was used. TensorFlow was installed using pip in both cases (pre-compiled binary, each time via pip3 install tensorflow-gpu). Using CUDA 8.0 and cuDNN v5 (though this should not affect the use case, since no cuDNN kernels are used). The GPU is a Titan X Pascal with 12 GB of VRAM (not a Titan Xp).

To reproduce:

import argparse
import psutil

from os import getpid
import tensorflow as tf
import numpy as np

def fc(inputs, output_size):
    with tf.variable_scope("FC"):
        input_size = inputs.get_shape()[-1].value
        W = tf.get_variable("W", shape=[input_size, output_size])
        b = tf.get_variable("b", shape=[output_size], initializer=tf.constant_initializer(0))
        out = tf.nn.xw_plus_b(inputs, W, b)
    return out

def create_model(input_size, output_size):
    # model placeholders:
    with tf.variable_scope("Inputs"):
        input_placeholder = tf.placeholder(
            tf.float32, [None, input_size], name="input_placeholder"
        )
    # meaningless function of inputs
    op = tf.reduce_mean(tf.reduce_sum(fc(input_placeholder, output_size), 1))
    return input_placeholder, op

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_epochs', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=7000)
    parser.add_argument('--input_size', type=int, default=100)
    parser.add_argument('--output_size', type=int, default=100)
    parser.add_argument('--device', type=str, default="gpu:0")
    return parser.parse_args(args=args)

def create_batches(inputs, input_size, batch_size, n):
    batches = []
    for i in range(n):
        X = np.random.uniform(-1.0, 1.0, size=(batch_size, input_size))
        batches.append({inputs: X})
    return batches

def main():
    args = parse_args()
    session_conf = tf.ConfigProto(allow_soft_placement=True)
    np.random.seed(1234)
    process = psutil.Process(getpid())

    with tf.Session(config=session_conf) as session, tf.device(args.device):
        inputs, op = create_model(args.input_size, args.output_size)
        session.run(tf.global_variables_initializer())
        batches = create_batches(inputs, args.input_size, args.batch_size, 20)

        for epoch in range(args.max_epochs):
            before = process.memory_percent()
            for feed_dict in batches:
                session.run(op, feed_dict)
            after = process.memory_percent()
            print("MEMORY CHANGE %.4f -> %.4f" % (before, after))

if __name__ == "__main__":
    main()

The output will look like the following (the exact numbers are percentages of the machine's RAM, so they will vary by hardware; the main point is that memory keeps growing even though there is no variation between graph runs: the batches are all the same size and no randomness is left in the program):

MEMORY CHANGE 1.2427 -> 1.3101
MEMORY CHANGE 1.3101 -> 1.3103
MEMORY CHANGE 1.3103 -> 1.3104
MEMORY CHANGE 1.3104 -> 1.3106
MEMORY CHANGE 1.3106 -> 1.3108
MEMORY CHANGE 1.3108 -> 1.3108
MEMORY CHANGE 1.3108 -> 1.3108
...
MEMORY CHANGE 1.3108 -> 1.3109
...
MEMORY CHANGE 1.3109 -> 1.3110
...

How can I fix this? I currently suspect a CPU memory pool issue inside TensorFlow, since the problem is fairly generic and does not depend (much) on the ops inside the graph. From what I've gathered, the most likely candidate is the np.asarray copying of numpy arrays in feed_dict, leading to memory fragmentation etc. Supposing this were the case, I've heard that tcmalloc should alleviate it, but no dice. (Note: I've also checked that objgraph shows no growth in the program over time.)
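For completeness, a minimal sketch of the kind of objgraph check mentioned above, written as a drop-in replacement for the loop in main() of the script; the placement of the calls is illustrative, not the original diagnostic code:

import objgraph

objgraph.show_growth()  # establish a baseline of Python object counts
for epoch in range(args.max_epochs):
    before = process.memory_percent()
    for feed_dict in batches:
        session.run(op, feed_dict)
    after = process.memory_percent()
    print("MEMORY CHANGE %.4f -> %.4f" % (before, after))
    # prints only object types whose counts grew since the last call;
    # if the leak were on the Python side, some type would keep showing up here
    objgraph.show_growth(limit=5)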

@jart
Contributor

jart commented Apr 10, 2017

Thank you @JonathanRaiman for putting a lot of thought into communicating this issue. You've indicated you suspect there is a memory leak in the C++ code. In that case, @zhifengc might be able to advise you on how to troubleshoot down to the precise bug, or perhaps look into it himself.

@jart jart added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 10, 2017
@JonathanRaiman
Contributor Author

@jart Terrific. Thanks!

@Panaetius

Panaetius commented May 15, 2017

I have a similar issue and stumbled upon this report. I used the code supplied by @JonathanRaiman to quickly test which instruction is causing my issue.

After a lot of different tests evaluating different ops, I arrived at the following code, which reliably reproduces the problem:

import argparse
import psutil

from os import getpid
import tensorflow as tf
import numpy as np

def create_model(input_size, output_size):
    # model placeholders:
    shape = tf.clip_by_value(tf.cast(tf.random_normal([2]) * 38.0 + 64.0, tf.int32), 38, 120)
    shape = tf.concat([[1], shape, [512]], axis=0)

    return tf.ones(shape, dtype=tf.int32)

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_epochs', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=7000)
    parser.add_argument('--input_size', type=int, default=100)
    parser.add_argument('--output_size', type=int, default=100)
    parser.add_argument('--device', type=str, default="gpu:0")
    return parser.parse_args(args=args)

def main():
    args = parse_args()
    session_conf = tf.ConfigProto(allow_soft_placement=True)
    np.random.seed(1234)
    process = psutil.Process(getpid())

    with tf.Session(config=session_conf) as session, tf.device(args.device):
        op = create_model(args.input_size, args.output_size)
        session.run(tf.global_variables_initializer())
        before = process.memory_percent()

        for epoch in range(args.max_epochs):
            session.run(op)
            
            if epoch % 100 == 0:
                after = process.memory_percent()
                print("MEMORY CHANGE %.4f -> %.4f" % (before, after))
                before = after

if __name__ == "__main__":
    main()

The tf.ones(shape, dtype=tf.int32) instruction causes the issue. The same goes for tf.zeros, tf.ones_like and tf.zeros_like. The interesting part is that this ONLY happens with dtype=tf.int32; it doesn't happen for int64, int16, int8, uints, or floats.

Another observation: while the memory usage reported by Python on the first run is roughly the same for all of those data types, the memory usage shown in the XFCE Task Manager is more than twice as high for the int32 variant as for the other datatypes. So it seems like Python is reporting memory usage incorrectly when tf.int32 is used.

[Example plot: the first bump is int64 with no growth, the second bump is int32 with fast growth.]

Please also note that memory usage increases rather quickly (from 2% to 8% of memory in 10,000 iterations, which takes about 10-15 seconds), and that having multiple tf.ones instructions makes it grow even faster, which can have a pretty noticeable effect on larger and more complex models.

But this only happens when the input dimensions are random. If the shape supplied to tf.ones is the same on every run, memory use does not increase. So it only affects variable-sized tensors.

Also, tf.cast(tf.ones(shape, dtype=tf.int64), tf.int32) works fine.
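Applied to the reproduction above, that workaround would look roughly like this (a sketch; only the create_model body changes, everything else stays as posted):

def create_model(input_size, output_size):
    # random, per-run shape, exactly as before
    shape = tf.clip_by_value(tf.cast(tf.random_normal([2]) * 38.0 + 64.0, tf.int32), 38, 120)
    shape = tf.concat([[1], shape, [512]], axis=0)
    # build the constant as int64 and cast afterwards, sidestepping the int32 Fill path
    return tf.cast(tf.ones(shape, dtype=tf.int64), tf.int32)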

I'm not 100% sure this is the same issue as @JonathanRaiman's, since he's not using int32 as far as I can tell, but his example does have a "None" dimension and the behaviour looks exactly the same.

I'm using TensorFlow built from master yesterday, though the problem also existed in 1.1, on Arch Linux.

@JonathanRaiman
Contributor Author

@Panaetius My original memory leak arose in code that had several int32 tensors being fed in with varying sizes (inputs to embedding lookup tables with varying numbers of time steps going into RNNs). This sounds like it might be the same issue.
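For context, that pattern looks roughly like this (a minimal sketch with made-up sizes, not the original RNN code):

import numpy as np
import tensorflow as tf

vocab_size, embed_dim = 1000, 64
# int32 indices with a variable number of time steps, fed via feed_dict
token_ids = tf.placeholder(tf.int32, [None, None], name="token_ids")  # [batch, time]
embedding_matrix = tf.get_variable("embedding_matrix", [vocab_size, embed_dim])
embedded = tf.nn.embedding_lookup(embedding_matrix, token_ids)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for step in range(10):
        time_steps = np.random.randint(5, 50)  # sequence length varies per batch
        batch = np.random.randint(0, vocab_size, size=(32, time_steps))
        session.run(embedded, feed_dict={token_ids: batch})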

@jstaker7

jstaker7 commented Aug 1, 2017

This is becoming problematic for me as well. I also use non-fixed input dimensions.

Related:
#8560
https://stackoverflow.com/questions/42861956/gpu-poolallocator-explodes-the-cpu-memory

@JonathanRaiman
Contributor Author

@jart @zhifengc Have these code snippets helped illuminate what might be the source for the leak? Is there any other information we could provide to help fix this?

@emsansone

@JonathanRaiman I'm facing the same problem, and I also suspect that it is due to the copying of numpy arrays in feed_dict.

@dantkz
Contributor

dantkz commented Nov 27, 2017

There is indeed different treatment for dtype=tf.int32 for tf.zeros, tf.zeros_like, tf.ones, tf.ones_like.

These ops are defined in tensorflow/python/ops/array_ops.py.

They all call the fill function, whose kernel (FillOp) is defined in tensorflow/core/kernels/constant_op.cc.

The FillOp is registered as usual for all types except int32:

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Fill")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("dims")
                            .HostMemory("value")
                            .HostMemory("output"),
                        FillOp<CPUDevice, int32>);

ConstantOp in the same file also has this special treatment.

The REGISTER_KERNEL_BUILDER macro is defined in tensorflow/core/framework/op_kernel.h.
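Consistent with that, a quick A/B comparison at the Python level (a sketch derived from @Panaetius's reproduction above, not a definitive benchmark) shows unbounded growth only for the int32 variant:

from os import getpid
import psutil
import tensorflow as tf

def run_variant(dtype, steps=2000):
    tf.reset_default_graph()
    # random, per-run shape, as in the reproduction above
    shape = tf.clip_by_value(tf.cast(tf.random_normal([2]) * 38.0 + 64.0, tf.int32), 38, 120)
    shape = tf.concat([[1], shape, [512]], axis=0)
    op = tf.ones(shape, dtype=dtype)
    process = psutil.Process(getpid())
    with tf.Session() as session:
        before = process.memory_percent()
        for _ in range(steps):
            session.run(op)
        return before, process.memory_percent()

for dtype in (tf.int64, tf.int32):
    before, after = run_variant(dtype)
    print(dtype.name, "MEMORY CHANGE %.4f -> %.4f" % (before, after))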

@dantkz
Contributor

dantkz commented Nov 27, 2017

This may be related to issue #13221.

@tensorflowbutler
Member

It has been 14 days with no activity and the awaiting tensorflower label was assigned. Please update the label and/or status accordingly.

@jart
Contributor

jart commented Dec 22, 2017

Thanks for pointing out that #13221 is likely a duplicate of this. That issue has gained a lot of attention from the development team, with information on what should happen, and we're hoping someone in the community will volunteer to be the one to solve it. Please follow the other issue from here out.

@JonathanRaiman
Contributor Author

Appears to be fixed in tf 1.6

@hbb21st

hbb21st commented Jun 6, 2018

No, not yet. I still ran into this problem in tf 1.8.

@galeone

galeone commented Jun 15, 2018

I second @hbb21st: I'm facing the same (or a very similar) issue in TensorFlow 1.8, running a loop of sess.run and passing via feed_dict just a string with the filename to read:

for image in glob(os.path.join(validationset_path, "*.png")):
    embed = sess.run(latent, feed_dict={filename_: image})

After some iterations:

E tensorflow/stream_executor/cuda/cuda_driver.cc:967] failed to alloc 20288458752 bytes on host: CUDA_ERROR_OUT_OF_MEMORY
W ./tensorflow/core/common_runtime/gpu/pool_allocator.h:195] could not allocate pinned host memory of size: 20288458752

@Satcheel

Satcheel commented Jul 6, 2018

I'm getting a similar problem. I'm generating batches manually from different .h5 files.

with sess.as_default():
    full_start = clock()
    for i in range(100):    # epochs
        start = clock()
        batch_train_size = 512
        batch_test_size = 200
        start_train_index = 0
        end_train_index = start_train_index + batch_train_size
        start_test_index = 0
        end_test_index = start_test_index + batch_test_size
        for j in range(int(ceil(float(train_len) / batch_train_size))):
            # wrap the indices around once we run out of data
            if start_train_index >= train_len:
                start_train_index = 0
            if start_test_index >= test_len:
                start_test_index = 0
            end_train_index = start_train_index + batch_train_size
            end_test_index = start_test_index + batch_test_size
            if end_train_index >= train_len:
                end_train_index = train_len
            if end_test_index >= test_len:
                end_test_index = test_len
            print 'epoch:', i + 1, '/100 batch_num:', j + 1, '/19'
            x_train, y_train, x_test, y_test = loadData(start_train_index, end_train_index,
                                                        start_test_index, end_test_index)
            start_train_index = end_train_index
            start_test_index = end_test_index
            print x_train.shape, x_test.shape
            train_step.run(feed_dict={X_train: x_train,
                                      labels: y_train})

The loadData function returns padded input features of different videos from the .h5 files.

After a few batches:

2018-07-06 18:04:56.278757: W ./tensorflow/core/common_runtime/gpu/pool_allocator.h:195] could not allocate pinned host memory of size: 7730940928
Killed

Can someone suggest a way to load batches manually without exhausting memory?
@JonathanRaiman @hbb21st please let me know if you solved your error.

@masadcv

masadcv commented Apr 22, 2019

@Satcheel were you able to solve this issue? I am getting a similar memory leak when I change my feed_dict to extract specific tensor values in a single session. Any help in this regard would be really helpful! Many thanks!

@Satcheel

Sorry, I am not very active on GitHub.

Leaving the answer here for future reference.
My problem was not a memory leak but my huge training data (250 GB). So I used Keras with steps_per_epoch for training and a generator function to load the data. I was trying to do the same in native TF but couldn't work it out.
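A minimal sketch of that pattern (hypothetical generator and model, not the actual training code):

import numpy as np
from tensorflow import keras

def batch_generator(num_batches, batch_size=512, input_size=100):
    # in practice each batch would be read from the .h5 files instead of generated
    while True:  # Keras expects the generator to loop indefinitely
        for _ in range(num_batches):
            x = np.random.rand(batch_size, input_size)
            y = np.random.randint(0, 2, size=(batch_size, 1))
            yield x, y

model = keras.models.Sequential([
    keras.layers.Dense(1, activation="sigmoid", input_shape=(100,))
])
model.compile(optimizer="adam", loss="binary_crossentropy")
# only one batch lives in memory at a time
model.fit_generator(batch_generator(num_batches=19), steps_per_epoch=19, epochs=100)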
