Potential memory leak when using LSTM + TimeDistributed #33178

Closed
algrmur opened this issue Oct 9, 2019 · 32 comments
Assignees
Labels
comp:keras Keras related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.0 Issues relating to TensorFlow 2.0 type:bug Bug

Comments

@algrmur commented Oct 9, 2019


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution: Windows 10
  • TensorFlow installed from: binary
  • TensorFlow version: v2.0.0-rc2-26-g64c3d382ca 2.0.0
  • Python version: 3.6.9
  • Bazel version (if compiling from source): N/A
  • GCC/Compiler version (if compiling from source): N/A
  • CUDA/cuDNN version: 10.0 (CUDA) / 7.5 (cuDNN)
  • GPU model and memory: TITAN RTX (24GB)
  • Exact command to reproduce: N/A

Describe the problem

Potential memory leak when using LSTM + TimeDistributed

I have a standard time series model that consists of 3 convolutional layers feeding into 2 LSTM layers. Up until now, I have had no problems mapping a Dense layer to the last output of the top LSTM, making a prediction, etc. However, I want to implement a model where I use a TimeDistributed(Dense(..)) layer on top of the top LSTM and feed back the error signal at each time point. I have implemented this, but after training only a few epochs, I get a resource exhausted error.

It doesn't seem to matter how small I make the model; it still fails after training for a few epochs. The error I get is: "ResourceExhaustedError: OOM when allocating tensor with shape[25600,9,11,128]". This comes after a call to tape.gradient (full error reported in the section below).

In my non-TimeDistributed model I monitor the number of objects via len(gc.get_objects()), and during training the object count remains the same (as expected). But when I change only what is needed for the TimeDistributed version (i.e. making sure the labels are correctly repeated and setting return_sequences=1 for the top-level LSTM), thousands of new objects are suddenly added during each epoch of training.

gc objects: 249861
[TRAIN]: End (epoch 0): loss 0.693269372 ; train accuracy 0.5
[TEST]: End (epoch 0): loss 0.691318274 ; test accuracy 0.500683606
gc objects: 251746 (+ 1885 objects)
[TRAIN]: End (epoch 1): loss 0.691800237 ; train accuracy 0.500202894
[TEST]: End (epoch 1): loss 0.690349817 ; test accuracy 0.502343774
gc objects: 254144 (+ 2398 objects)
[TRAIN]: End (epoch 2): loss 0.690762699 ; train accuracy 0.500456572
[TEST]: End (epoch 2): loss 0.689480364 ; test accuracy 0.504296899
gc objects: 254996 (+852 objects)
[TRAIN]: End (epoch 3): loss 0.692312837 ; train accuracy 0.501090705
[TEST]: End (epoch 3): loss 0.689140499 ; test accuracy 0.505468726
gc objects: 269643 (+ 14647 objects)
[TRAIN]: End (epoch 4): loss 0.688487 ; train accuracy 0.501116097
[TEST]: End (epoch 4): loss 0.686942577 ; test accuracy 0.508886695
gc objects: 270444 (+ 801 objects)

So in 4 epochs of training, with no other process running, 20,583 new objects were created, which I presume resulted in this resource exhausted error.

I've tried to force the garbage collector to collect any unused variables but the object count increases whether this is included or not. I ran a snapshot comparison from the tracemalloc library, which I will include below as it might be helpful (it wasn't to me).

Something is creating variables during every epoch and not releasing them, steadily using up all the memory and leading to this resource exhausted error. This doesn't occur if I don't use TimeDistributed, and I don't think anything about this layer requires the creation of additional memory-hungry variables. It looks more like a leak.

Do you have any idea of what I could do to alleviate this problem? It looks like a bug that needs a fix at a technical level, though perhaps there is a workaround in the meantime. Please let me know if any further info from my end would be useful in looking at this issue.

Source code / logs

tracemalloc top 10 differences between snapshot calls at adjacent epochs

C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\eager\execute.py:61: size=111 KiB (+69.9 KiB), count=677 (+426), average=168 B
:14: size=7464 B (-46.9 KiB), count=107 (-749), average=70 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\tokenize.py:609: size=2944 B (-43.6 KiB), count=46 (-698), average=64 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py:193: size=59.9 KiB (+33.8 KiB), count=1305 (+732), average=47 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\training\tracking\data_structures.py:768: size=54.0 KiB (+31.3 KiB), count=386 (+219), average=143 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py:718: size=55.7 KiB (+30.8 KiB), count=1018 (+564), average=56 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py:776: size=51.0 KiB (+28.7 KiB), count=1235 (+690), average=42 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\keras\utils\generic_utils.py:564: size=40.9 KiB (+25.8 KiB), count=675 (+426), average=62 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\framework\ops.py:1035: size=39.3 KiB (+23.3 KiB), count=950 (+566), average=42 B
C:\Users\AXM1390\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\training\tracking\data_structures.py:809: size=27.1 KiB (+15.9 KiB), count=3 (+0), average=9248 B
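
A pattern roughly like the following minimal sketch can produce such a per-epoch diff (the train_one_epoch placeholder stands in for the real training loop and is not part of the original code):

import gc
import tracemalloc

def train_one_epoch():
    # Placeholder for the real training loop; it only needs to allocate something.
    return [object() for _ in range(1000)]

tracemalloc.start()
prev = tracemalloc.take_snapshot()
for epoch in range(3):
    train_one_epoch()
    gc.collect()
    snap = tracemalloc.take_snapshot()
    # Print the ten allocation sites that grew the most since the previous epoch.
    for stat in snap.compare_to(prev, 'lineno')[:10]:
        print(stat)
    prev = snap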

full error

ResourceExhaustedError Traceback (most recent call last)
in
----> 1 best_val, best_epoch, tmp_history = run_model_once(0, 25, epochs=50)

in run_model_once(start, end, epochs)
36 printed_cm = False
37
---> 38 train_loss, val_loss, acc_metric, val_acc_metric = train(RCNN_model, optimizer, train_ds, test_ds, cm)
39 tf.print(f'[TRAIN]: End (epoch {i}): loss', train_loss, '; train accuracy', acc_metric.result())
40 tf.print(f'[TEST]: End (epoch {i}): loss', val_loss, '; test accuracy', val_acc_metric.result())

in train(model, optimizer, train_ds, test_ds, cm)
60 for x_true, y_true in train_ds:
61 if TIME_DISTRIBUTED:
---> 62 train_loss = train_one_step_timedistributed(model, optimizer, x_true, y_true, training=True)
63 else:
64 train_loss = train_one_step(model, optimizer, x_true, y_true, training=True)

in train_one_step_timedistributed(model, optimizer, x_true, y_true, training)
22 print(f'model trainable variables: {len(model.trainable_variables)}')
23
---> 24 gradients = tape.gradient(loss_, model.trainable_variables)
25 optimizer.apply_gradients(zip(gradients, model.trainable_variables))
26

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\eager\backprop.py in gradient(self, target, sources, output_gradients, unconnected_gradients)
1012 output_gradients=output_gradients,
1013 sources_raw=flat_sources_raw,
-> 1014 unconnected_gradients=unconnected_gradients)
1015
1016 if not self._persistent:

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\eager\imperative_grad.py in imperative_grad(tape, target, sources, output_gradients, sources_raw, unconnected_gradients)
74 output_gradients,
75 sources_raw,
---> 76 compat.as_str(unconnected_gradients.value))

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\eager\backprop.py in _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, out_grads, skip_input_indices)
136 return [None] * num_inputs
137
--> 138 return grad_fn(mock_op, *out_grads)
139
140

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\ops\math_grad.py in _TanhGrad(op, grad)
712 with ops.control_dependencies([grad]):
713 y = math_ops.conj(y)
--> 714 return gen_math_ops.tanh_grad(y, grad)
715
716

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py in tanh_grad(y, dy, name)
  11410       else:
  11411         message = e.message
> 11412       _six.raise_from(_core._status_to_exception(e.code, message), None)
  11413   # Add nodes to the TensorFlow graph.
  11414   _, _, _op = _op_def_lib._apply_op_helper(

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\six.py in raise_from(value, from_value)

ResourceExhaustedError: OOM when allocating tensor with shape[25600,9,11,128] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:TanhGrad]

@oanush oanush self-assigned this Oct 10, 2019
@oanush oanush added comp:apis Highlevel API related issues TF 2.0 Issues relating to TensorFlow 2.0 type:support Support issues labels Oct 10, 2019
@oanush commented Oct 10, 2019

@algrmur,
Can you share simple, standalone code that reproduces the issue reported here? Thanks!

@oanush oanush added the stat:awaiting response Status - Awaiting response from author label Oct 10, 2019
@akanyaani

Hi @algrmur
There are a couple of bugs in TF 2.0; one possible cause could be variable sequence length. It would be great if you could share sample code that reproduces the error.

@algrmur (Author) commented Oct 10, 2019

Hello @oanush and @akanyaani,

Thanks for your replies. I've stripped out all of the non-essential code in my program and set it up to use randomly generated data to recreate the problem. Please see below for (1) the code and (2) the output I get.

Imports + Data Generation

import numpy as np
import tensorflow as tf 
from tensorflow.keras.layers import (Conv2D, LSTM, Dense, Dropout,
                                     Softmax, BatchNormalization, TimeDistributed)
from scipy.stats import mode
import gc

print(f'Using TensorFlow {tf.__version__}, GPU available? : {tf.test.is_gpu_available()}')

n_data = 5000
height = 9
width = 11
n_timepoints = 100
chan_dim = 1

train_x = np.random.rand(n_data, height, width, n_timepoints, chan_dim)
train_y = np.random.randint(low=0, high=3, size=n_data)
train_y = tf.keras.utils.to_categorical(train_y)

test_x = np.random.rand(n_data, height, width, n_timepoints, chan_dim)
test_y = np.random.randint(low=0, high=3, size=n_data)
test_y = tf.keras.utils.to_categorical(test_y)
                           
print(f'train_x shape: {train_x.shape}')
print(f'train_y shape: {train_y.shape}')
print(f'test_x shape: {test_x.shape}')
print(f'test_y shape: {test_y.shape}')

Model definition

class mymodel(tf.keras.Model):
    
    def __init__(self, n_filters, n_fc, n_output, n_batch, n_nodes, dropout):
        super(mymodel, self).__init__(name='RCNN')
        
        # Set model properties as instance attributes
        self.n_filters = n_filters
        self.n_fc = n_fc
        self.n_output = n_output
        self.N_BATCH = n_batch
        self.N_NODES = n_nodes
        self.DROPOUT = dropout
        self.out_activation = "sigmoid" if n_output == 2 else "softmax"
        
        self.conv1 = Conv2D(filters=n_filters, strides=1,padding='same', 
                            activation='tanh', kernel_size=3)
        
        self.conv2 = Conv2D(filters=n_filters*2, strides=1,padding='same', 
                            activation='tanh', kernel_size=3)
        
        self.conv3 = Conv2D(filters=n_filters*4, strides=1,padding='same', 
                            activation='tanh', kernel_size=3)
            
        self.dense1 = Dense(n_fc)
        self.dropout1 = Dropout(dropout) 
        
        self.lstm1 = LSTM(self.N_NODES, recurrent_initializer='orthogonal', return_sequences=1)
        self.lstm2 = LSTM(self.N_NODES, recurrent_initializer='orthogonal', return_sequences=1)
        
        self.fc2         = TimeDistributed(Dense(n_fc))
        self.fc2_dropout = TimeDistributed(Dropout(dropout))
        self.outputlayer = TimeDistributed(Dense(n_output, activation=self.out_activation))
        
    def call(self, inputs, training):
        
        _, height, width, n_timesteps, n_input = inputs.shape
        inputs = tf.reshape(inputs, [self.N_BATCH*n_timesteps, height, width, 1])
        
        conv1_ = self.conv1(inputs)
        conv2_ = self.conv2(conv1_)
        conv3_ = self.conv3(conv2_)
        
        flattened_ = tf.reshape(conv3_, [-1, conv3_.shape[1]*conv3_.shape[2]*conv3_.shape[3]])
        dense1_   = self.dense1(flattened_)
        dropout1_ = self.dropout1(dense1_, training=training)
        
        lstm_in_ = tf.reshape(dropout1_, [-1, n_timesteps, self.n_fc])
        lstm1_   = self.lstm1(lstm_in_)
        lstm2_   = self.lstm2(lstm1_)
        
        fc2_         = self.fc2(lstm2_)
        fc2_dropout_ = self.fc2_dropout(fc2_, training=training)
        output_      = self.outputlayer(fc2_dropout_)
        
        return output_

Loss/Optimiser/tf.Dataset/Model instantiation

optimizer = tf.keras.optimizers.Adam(1e-4)
loss_fn = tf.keras.losses.CategoricalCrossentropy()

train_ds = tf.data.Dataset.from_tensor_slices((train_x,
                                               train_y)).batch(256, drop_remainder=True)
test_ds = tf.data.Dataset.from_tensor_slices((test_x,
                                              test_y)).batch(256, drop_remainder=True)

my_model = mymodel(n_filters=32, n_fc=256, n_batch=256, dropout=0.5,
                   n_output=train_y.shape[1], n_nodes=256)

Training functions

def train_one_step(model, optimizer, x_true, y_true, training):
    with tf.GradientTape() as tape:
        y_pred = model(x_true, training)
        y_true_expanded = np.repeat(y_true, n_timepoints, axis=0).reshape((y_pred.shape[0], n_timepoints, -1))
        loss_ = loss_fn(y_true_expanded, y_pred)
        
    gradients = tape.gradient(loss_, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    return loss_

def test_one_step(model, optimizer, x_true, y_true, training):
    y_pred = model(x_true, training)
    y_true_expanded = np.repeat(y_true, n_timepoints, axis=0).reshape((y_pred.shape[0], n_timepoints, -1))
    loss_ = loss_fn(y_true_expanded, y_pred)
    return loss_

def train(model, optimizer, train_ds, test_ds):
    
    for x_true, y_true in train_ds:
        train_loss = train_one_step(model, optimizer, x_true, y_true, training=True) 
        
    for x_true, y_true in test_ds:
        val_loss = test_one_step(model, optimizer, x_true, y_true, training=False)
            
    return (train_loss, val_loss)

Running the model

for i in range(200):
    gc.collect()
    print(f'gc objects: {len(gc.get_objects())}')
    
    train_loss, val_loss = train(my_model, optimizer, train_ds, test_ds)
    tf.print(f'[TRAIN]: End (epoch {i}): loss', train_loss)
    tf.print(f'[TEST]:  End (epoch {i}): loss', val_loss)

The output of this shows that many new objects are created as the epochs run.

Output

gc objects: 223324
[TRAIN]: End (epoch 0): loss 1.09557235
[TEST]:  End (epoch 0): loss 1.10084486
gc objects: 225044
[TRAIN]: End (epoch 1): loss 1.09603643
[TEST]:  End (epoch 1): loss 1.09944308
gc objects: 225427
[TRAIN]: End (epoch 2): loss 1.09861755
[TEST]:  End (epoch 2): loss 1.09861553
gc objects: 225814
[TRAIN]: End (epoch 3): loss 1.09743845
[TEST]:  End (epoch 3): loss 1.09851873
gc objects: 226235
[TRAIN]: End (epoch 4): loss 1.09701049
[TEST]:  End (epoch 4): loss 1.09853053
gc objects: 226551
[TRAIN]: End (epoch 5): loss 1.09691608
[TEST]:  End (epoch 5): loss 1.09853923
gc objects: 226912
---------------------------------------------------------------------------
ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input-6-b4dc6b381cfe> in <module>
      3     print(f'gc objects: {len(gc.get_objects())}')
      4 
----> 5     train_loss, val_loss = train(my_model, optimizer, train_ds, test_ds)
      6     tf.print(f'[TRAIN]: End (epoch {i}): loss', train_loss)
      7     tf.print(f'[TEST]:  End (epoch {i}): loss', val_loss)

<ipython-input-5-5f343d76c3bf> in train(model, optimizer, train_ds, test_ds)
     21 
     22     for x_true, y_true in train_ds:
---> 23         train_loss = train_one_step(model, optimizer, x_true, y_true, training=True)
     24 
     25     for x_true, y_true in test_ds:

<ipython-input-5-5f343d76c3bf> in train_one_step(model, optimizer, x_true, y_true, training)
      5         loss_ = loss_fn(y_true_expanded, y_pred)
      6 
----> 7     gradients = tape.gradient(loss_, model.trainable_variables)
      8     optimizer.apply_gradients(zip(gradients, model.trainable_variables))
      9 

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\eager\backprop.py in gradient(self, target, sources, output_gradients, unconnected_gradients)
   1012         output_gradients=output_gradients,
   1013         sources_raw=flat_sources_raw,
-> 1014         unconnected_gradients=unconnected_gradients)
   1015 
   1016     if not self._persistent:

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\eager\imperative_grad.py in imperative_grad(tape, target, sources, output_gradients, sources_raw, unconnected_gradients)
     74       output_gradients,
     75       sources_raw,
---> 76       compat.as_str(unconnected_gradients.value))

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\eager\backprop.py in _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, out_grads, skip_input_indices)
    136     return [None] * num_inputs
    137 
--> 138   return grad_fn(mock_op, *out_grads)
    139 
    140 

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\ops\math_grad.py in _TanhGrad(op, grad)
    712   with ops.control_dependencies([grad]):
    713     y = math_ops.conj(y)
--> 714     return gen_math_ops.tanh_grad(y, grad)
    715 
    716 

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py in tanh_grad(y, dy, name)
  11410       else:
  11411         message = e.message
> 11412       _six.raise_from(_core._status_to_exception(e.code, message), None)
  11413   # Add nodes to the TensorFlow graph.
  11414   _, _, _op = _op_def_lib._apply_op_helper(

~\AppData\Local\Continuum\anaconda3\envs\tf2\lib\site-packages\six.py in raise_from(value, from_value)

ResourceExhaustedError: OOM when allocating tensor with shape[25600,9,11,128] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:TanhGrad]

@akanyaani - Thanks for your comment about it potentially being related to variable length input, but in my case (and also in the case of the dummy code above) there is no variable length input as there are always 100 time points in the data I am using.

If you:

  • take out the 3 lines using "TimeDistributed" in the bottom of the model's "call" function
  • alter the call to the loss function to use "y_true" and not "y_true_expanded",
  • set the last LSTM's return_sequences parameter to 0

then the same code runs while keeping the number of objects relatively stable, as seen in the output further below; a sketch of those changes is shown first. This is what made me sure the problem was related to TimeDistributed.
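
To make those three changes concrete, here is a sketch of that variant expressed as a subclass of the model above (the class name mymodel_no_td is just for illustration and is not in the original code):

class mymodel_no_td(mymodel):
    """The model above, with the TimeDistributed head replaced by plain layers."""

    def __init__(self, n_filters, n_fc, n_output, n_batch, n_nodes, dropout):
        super().__init__(n_filters, n_fc, n_output, n_batch, n_nodes, dropout)
        # Only the last LSTM output is used, so the top LSTM no longer returns
        # sequences and the head layers are no longer wrapped in TimeDistributed.
        self.lstm2 = LSTM(self.N_NODES, recurrent_initializer='orthogonal', return_sequences=0)
        self.fc2 = Dense(n_fc)
        self.fc2_dropout = Dropout(dropout)
        self.outputlayer = Dense(n_output, activation=self.out_activation)

    def call(self, inputs, training):
        _, height, width, n_timesteps, n_input = inputs.shape
        inputs = tf.reshape(inputs, [self.N_BATCH * n_timesteps, height, width, 1])
        conv3_ = self.conv3(self.conv2(self.conv1(inputs)))
        flattened_ = tf.reshape(conv3_, [-1, conv3_.shape[1] * conv3_.shape[2] * conv3_.shape[3]])
        dropout1_ = self.dropout1(self.dense1(flattened_), training=training)
        lstm_in_ = tf.reshape(dropout1_, [-1, n_timesteps, self.n_fc])
        lstm2_ = self.lstm2(self.lstm1(lstm_in_))  # (batch, N_NODES): last time step only
        fc2_dropout_ = self.fc2_dropout(self.fc2(lstm2_), training=training)
        return self.outputlayer(fc2_dropout_)

# The training/test steps then use the un-repeated labels directly:
#     loss_ = loss_fn(y_true, y_pred)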

Output without using TimeDistributed

[TRAIN]: End (epoch 0): loss 1.11176896
[TEST]:  End (epoch 0): loss 1.10250723
gc objects: 224310
[TRAIN]: End (epoch 1): loss 1.10939479
[TEST]:  End (epoch 1): loss 1.0968765
gc objects: 224340
[TRAIN]: End (epoch 2): loss 1.09383631
[TEST]:  End (epoch 2): loss 1.09696317
gc objects: 224338
[TRAIN]: End (epoch 3): loss 1.10116696
[TEST]:  End (epoch 3): loss 1.09682214
gc objects: 224340
[TRAIN]: End (epoch 4): loss 1.09012806
[TEST]:  End (epoch 4): loss 1.09699178
gc objects: 224310
[TRAIN]: End (epoch 5): loss 1.09775305
[TEST]:  End (epoch 5): loss 1.09818268
gc objects: 224310
[TRAIN]: End (epoch 6): loss 1.09012258
[TEST]:  End (epoch 6): loss 1.09953499
gc objects: 224342
[TRAIN]: End (epoch 7): loss 1.08458531
[TEST]:  End (epoch 7): loss 1.09981978
gc objects: 224340
[TRAIN]: End (epoch 8): loss 1.08050597
[TEST]:  End (epoch 8): loss 1.10207307
gc objects: 224310
[TRAIN]: End (epoch 9): loss 1.07563555
[TEST]:  End (epoch 9): loss 1.10376084
gc objects: 224340
[TRAIN]: End (epoch 10): loss 1.07361245
[TEST]:  End (epoch 10): loss 1.10961676
gc objects: 224310
[TRAIN]: End (epoch 11): loss 1.05968428
[TEST]:  End (epoch 11): loss 1.11357927
gc objects: 224310
[TRAIN]: End (epoch 12): loss 1.05215323
[TEST]:  End (epoch 12): loss 1.11662757
gc objects: 224310
....

I hope you can also recreate the issue and thereby potentially see where the problem might lie.
Any advice / solutions would be greatly appreciated!

Kind regards
Alex

@akanyaani

Hi @algrmur
[Screenshot of the run output, 2019-10-10]

It works fine on my system. Could you please try a smaller batch size?

@algrmur (Author) commented Oct 10, 2019

Hi @akanyaani,

Interesting! I tried via notebooks, command line etc. and it always gave the same error. I will try it on my Linux laptop to see if it also breaks there. Do you mind if I ask about your specs so I can see what else I might be able to try (mainly just CUDA / cuDNN version and what version of Python you used)? Then I can try on Windows with the same things running as you, as that might fix my problem.

Once I am back at my main workstation tomorrow, I will try a lower batch size to answer your question as to whether that might be the problem. I didn't think it would be, because TimeDistributed(Dense(..)) uses the same weights for each time step, so I thought the computation would be equivalent (in terms of the gradient call) to the non-TimeDistributed case (which does work). I could be wrong though. Furthermore, I don't know why it would be fine for the first 5-6 epochs and then fail afterwards. If it can handle the first few, nothing new should happen during training, so I still have no explanation as to how the OOM could occur.
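
For what it's worth, a quick sketch of that shared-weights point (my own check, with illustrative shapes: 256 input features and 3 output units):

import tensorflow as tf
from tensorflow.keras.layers import Dense, TimeDistributed

x = tf.random.normal((4, 100, 256))  # (batch, time, features)

td_dense = TimeDistributed(Dense(3))
plain_dense = Dense(3)

print(td_dense(x).shape, plain_dense(x).shape)               # both (4, 100, 3)
print(td_dense.count_params(), plain_dense.count_params())   # both 256 * 3 + 3 = 771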

More information tomorrow!
Thanks again!

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Oct 11, 2019
@Tetragramm (Contributor) commented Oct 14, 2019

I am also seeing a memory leak. No LSTM though, just TimeDistributed. This model fails after printing iteration 11 on a 2080 Ti. The batch size is always 1, so that's not the problem.
EDIT: Shrunk the minimal example by quite a bit. It is now 1 Conv1D layer and 1 TimeDistributed Dense layer.

from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
import numpy as np
import tensorflow

def BuildGenerator():
    i = Input(shape=(None,2,))
    
    n_input = 12*21
    to_n = Input(shape=(n_input,))  # shape must be a tuple, not a bare int
    s_n = Dense(12*21, activation='softmax')(to_n)
    s_n = Reshape((12,21))(s_n)
    n_base = Model(inputs=[to_n], outputs=[s_n])
    
    b = Conv1D(n_input, 11, dilation_rate=1, padding='same', activation='relu', data_format='channels_last')(i)
    n = TimeDistributed(n_base)(b)

    return Model(inputs=[i], outputs=[n])

def InputGenerator():
    for iter in range(1000):
        print(iter)
        i = np.zeros((1,10*60*1000,2))
        n = np.zeros((1,10*60*1000,12,21))
        yield ([i], [n])

with tensorflow.device('/device:gpu:0'):
    
    m2t = BuildGenerator()
    
    m2t.compile(optimizer='adam', loss='mse')
    
    for epoch in range(1):
        for inout in InputGenerator():
            m2t.train_on_batch(inout[0], inout[1])

@Tetragramm (Contributor)

So, does this line do what I think it does?

self._input_map[input_uid] = inputs

Because that looks like it's permanently storing the inputs in a map that never gets cleared, using a global UUID as the key.
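
That would explain the steady growth. A toy illustration of the pattern (not TensorFlow's actual code, just the shape of the problem):

import itertools

_uid = itertools.count()  # stands in for the global UID generator

class LeakyWrapper:
    """Toy only: stores every input under a key that is never reused or removed."""

    def __init__(self):
        self._input_map = {}

    def __call__(self, inputs):
        input_uid = next(_uid)               # a fresh key on every call...
        self._input_map[input_uid] = inputs  # ...so every past input stays referenced
        return inputs

w = LeakyWrapper()
for _ in range(5):
    w(bytearray(10 ** 6))                    # 1 MB per call that can never be freed
print(len(w._input_map))                     # 5, and it only ever grows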

@oanush commented Oct 15, 2019

@algrmur, any update on the issue? Thanks!

@oanush oanush added the stat:awaiting response Status - Awaiting response from author label Oct 15, 2019
@algrmur (Author) commented Oct 15, 2019

Hi oanush,

I tried running a reduced model with a very small batch size (16); it ran longer than it did last time, but there was still a considerable increase in the number of objects in memory on each training loop (a few hundred per iteration). The smaller batch just left a bit more room for more epochs to run before hitting OOM errors at a later point.

I think Tetragramm (above) is having the same issue, which further confirms my belief that the problem is with the TimeDistributed layer (without it, my model runs fine). I wanted to run the same model using the same setup as akanyaani, but so far I haven't seen the exact details of his setup or whether he has much more memory available.

@arthurflor23

I had this issue in version 2.0.0.

The Beta1 version works and runs faster per epoch.

@Tetragramm (Contributor)

@arthurflor23 Are you sure? The Beta1 code is nearly identical to the released 2.0.0 code. The only change is to regularization, which has nothing in it that would cause a memory leak.

Upon further review (actually reading through the code), I'm pretty sure the _input_map variable is both the cause and unnecessary. I think lines 56, 246, 247, 249, 308, and 309 can be removed, and line 310 replaced with

output_mask = self.layer.compute_mask(inputs, inner_mask)

Unfortunately, I'm having trouble building tensorflow from source to test.

@arthurflor23 commented Oct 15, 2019

Hi!
I was training a model today (via Google Colab) using tensorflow-gpu==2.0.0. I added two TimeDistributed layers and noticed that the epoch time kept increasing: it started at ~200 s for the first epoch and ended at ~350 s for the last, and then the issue mentioned here happened.

I don't know if it's related to this or to another module version, but on Beta1 it doesn't happen and each epoch takes ~140 s.

I will do more tests with the two models I'm studying, because I already have another problem in the recurrent layers with ThenRnnBackward.

Just to add, the behavior I mentioned has appeared since rc0 (I installed version by version to check, via Google Colab).

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Oct 16, 2019
Tetragramm added a commit to Tetragramm/tensorflow that referenced this issue Oct 16, 2019
@oanush oanush added comp:keras Keras related issues type:bug Bug and removed comp:apis Highlevel API related issues type:support Support issues labels Oct 17, 2019
@oanush oanush assigned rmothukuru and unassigned oanush Oct 18, 2019
@rmothukuru rmothukuru added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Oct 18, 2019
@sonfire186

Two weeks have passed; any success?

@Tetragramm (Contributor)

Not quite. Some complications. Awaiting someone with better understanding of the system than me.

@astier commented Oct 26, 2019

I also have a memory leak, from 1.14 up to 2.0. On 1.13 the leak disappears.

@sonfire186

I'm using TF 1.14 and do not have a memory leak.

@arnemoos commented Nov 2, 2019

Hi,

I'm using TF-GPU 2.0.0 and having the same issue when using the TimeDistributed wrapper...
Adding a

self._input_map.clear()

before this:

self._input_map[input_uid] = inputs

no longer results in increasing GPU memory allocation. But I don't know whether it is still correct: I only saw self._input_map being accessed during the preparation of training, and only for the most recent input_uid, so I thought that clearing it before adding the latest element would not break the logic behind it but would still fix the memory leak.
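
For anyone who would rather not edit the installed TensorFlow sources, the same idea can be applied from user code. A sketch, assuming the private _input_map attribute exists on the affected TF 2.0 TimeDistributed layers (the helper and callback names are mine):

import tensorflow as tf

def clear_timedistributed_maps(model):
    """Drop the cached inputs held by TimeDistributed layers so they can be freed."""
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.TimeDistributed):
            input_map = getattr(layer, '_input_map', None)
            if input_map is not None:
                input_map.clear()

class ClearInputMapCallback(tf.keras.callbacks.Callback):
    """Use with model.fit(...); for custom loops, call the helper after each step."""
    def on_train_batch_end(self, batch, logs=None):
        clear_timedistributed_maps(self.model)

Note that this only touches top-level layers; TimeDistributed layers nested inside sub-models would need a recursive walk.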

@dgrahn commented Nov 8, 2019

I can verify that @arnemoos 's workaround prevents the OOM for me.

@arthurflor23 commented Nov 17, 2019

I'm running tf-nightly-gpu today and I no longer get the error. Can anyone confirm?

@uahic commented Nov 26, 2019

I can confirm that using TimeDistributed also runs my model into resource allocation errors on TF 2.0.0. I am using the fit_generator() training function with a model that has 3 Conv2D layers, each wrapped in TimeDistributed, on batches with a total memory footprint of 39.32 MB (batch size = 32).

@arthurflor23 Will try tf-nightly-gpu now and confirm/not-confirm

@uahic commented Nov 26, 2019

@arthurflor23 I can confirm that the issue has been gone for me as well :)

@sonfire186

@arthurflor23, yes, it's true.

@johnght commented Dec 10, 2019

Again, I can confirm that TimeDistributed is the culprit. In my case, tf-nightly breaks my model. I solved the problem by writing a custom for-loop in a subclassed model rather than using TimeDistributed, but this bug has to be fixed for those using non-subclassed models.
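
A sketch of that kind of loop, applying one shared Dense layer step by step instead of wrapping it in TimeDistributed (names and shapes are illustrative, assuming a static time dimension):

import tensorflow as tf

class PerStepDense(tf.keras.Model):
    """Applies one shared Dense layer to every time step with an explicit loop."""

    def __init__(self, units):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):                      # inputs: (batch, time, features)
        outputs = []
        for t in range(inputs.shape[1]):         # assumes a static time dimension
            outputs.append(self.dense(inputs[:, t, :]))
        return tf.stack(outputs, axis=1)         # (batch, time, units)

x = tf.random.normal((4, 100, 256))
print(PerStepDense(3)(x).shape)                  # (4, 100, 3)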

@kohjingyu

I was able to work around this issue by reshaping my tensor to combine the first two dimensions, applying the convolution / dense layer, and reshaping back to the expected output shape.
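
A sketch of that reshape trick on a dummy tensor (shapes are illustrative; Dense shares its weights across all time steps either way):

import tensorflow as tf

dense = tf.keras.layers.Dense(3)
x = tf.random.normal((4, 100, 256))              # (batch, time, features)

flat = tf.reshape(x, (-1, x.shape[-1]))          # (batch*time, features) = (400, 256)
out = tf.reshape(dense(flat), (x.shape[0], x.shape[1], -1))
print(out.shape)                                 # back to (4, 100, 3)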

@Tetragramm (Contributor)

The fix has been merged.
#33441 (comment)

@goldiegadde (Contributor)

Thanks @Tetragramm . @algrmur please let us know if your issue has been fixed and we can close this issue.

tensorflow-copybara pushed a commit that referenced this issue Jan 24, 2020
Imported from GitHub PR #33441

Documented in [THIS](#33178) thread.
Based on the documentation of get() and generic_utils.object_list_uid, this has no functional effect, except to remove an unnecessary map that was growing with every input.

Tested using the example program in [THIS](#331...

PiperOrigin-RevId: 291277510
Change-Id: I97df3c26850ae460d41e5032bb71edd11c948670
@qlzh727 (Member) commented Jan 28, 2020

The PR has been rolled back due to a test failure. I will try to update the internal code/tests to fix the memory leak.


@sonfire186

@qlzh727 So is the problem solved? Can I use version 2 now?

@qlzh727 (Member) commented Feb 5, 2020

yes, the issue is resolved by d064c6f

mabounassif pushed a commit to mabounassif/tensorflow that referenced this issue Feb 13, 2020
Fix tensorflow#33178.

PiperOrigin-RevId: 292043221
Change-Id: Ife2fa9a2adf50424bb2c932044fe3db5f4bb42d5
@hunter2048

Hello, I still have this problem on Google Colab Pro. An OOM occurred when I was running the code, and once it did, GPU usage showed 15.08 GB / 16.00 GB and could not be cleared; I had to stop my Jupyter notebook and recreate the session to run again. Can someone tell me how to solve this problem?

@Hhessling commented Aug 11, 2020

Hello, I can confirm that this issue is not fixed, even in recent builds.

A simple model like:

from tensorflow.keras import layers, Model
import efficientnet.tfkeras as efn  # assumed source of `efn`; not stated in the original comment

img_input = layers.Input(shape=(16, 640, 480, 3), name="input")
x = layers.BatchNormalization()(img_input)
img_encoder = efn.EfficientNetB0(input_shape=(640, 480, 3), include_top=False)
act_60_40_144 = img_encoder.get_layer('block3a_activation')
img_encoder = Model(inputs=img_encoder.input, outputs=act_60_40_144.output)

x = layers.TimeDistributed(img_encoder)(x)
x = layers.TimeDistributed(layers.GlobalAvgPool2D())(x)
x = layers.Bidirectional(layers.GRU(x.shape[-1], return_sequences=True), 'sum')(x)

x = layers.Conv1D(3, 1, activation='softmax')(x)
mo = Model(inputs=img_input, outputs=x)

needs ~25 GB, even with recent nightly builds of TensorFlow.
