Difference in training accuracy and loss using gradientTape vs model.fit with binary_accuracy: A bug? #35585

Closed
amjass12 opened this issue Jan 4, 2020 · 28 comments
Assignees
Labels
comp:apis Highlevel API related issues comp:keras Keras related issues stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author type:bug Bug

Comments

@amjass12

amjass12 commented Jan 4, 2020

Hi all,

I am running a custom training loop using tf.GradientTape. It works, but I get different training accuracy from the GradientTape loop than from a plain model.fit call. I apologise if this should be a question for Stack Overflow; however, to the best of my knowledge the parameters are the same, so the two should produce the same (or at least very close) results. I therefore think there may be a bug, and if anyone can help me get to the bottom of this I would really appreciate it!

I have prepared a sequential model as follows:

model=tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units=64, input_dim=5078, activation="relu"))
model.add(tf.keras.layers.Dense(units=32, activation="relu"))
model.add(tf.keras.layers.Dense(units=100, activation="relu"))
model.add(tf.keras.layers.Dense(units=24, activation="sigmoid"))

and for the model.fit method, fit as follows:

model.compile(optimizer="Adam", loss="binary_crossentropy", metrics=["acc"])

model.fit(X_train, y_train,
 batch_size=32,
 epochs=100, verbose=1,
 validation_split=0.15,
 shuffle=True)

This works well and produces the following results. (Please note that 100 epochs is overkill and the model overfits; this is just to keep the same number of epochs as the GradientTape loop. Normally there would be an early-stopping callback.)

The model metrics are as follows:

 32/119 [=======>......................] - ETA: 0s - loss: 0.0699 - acc: 0.9753
119/119 [==============================] - 0s 168us/sample - **loss: 0.0668** - acc: **0.9779** - val_loss: **0.2350** - val_acc: **0.9048**

This is the expected behaviour (minus the overfitting). Now, when I create the GradientTape loop as follows, the accuracy metrics are off by about 4-5% over the same 100 epochs. I suspect a bug because I believe I am using the appropriate metrics:

def random_batch(X,y, batch_size=32):
    idx= np.random.randint(len(X), size=batch_size)
    return X[idx], y[idx]

##Further split train data to training set and validation set

X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.15, random_state=1)

##Run autodiff on model

n_epochs=100
batch_size=32
n_steps=len(X_train)//batch_size

optimizer=tf.keras.optimizers.Adam()
loss=tf.keras.losses.BinaryCrossentropy()

metricLoss=tf.keras.metrics.BinaryCrossentropy()
metricsAcc=tf.keras.metrics.BinaryAccuracy()

val_acc_metric=tf.keras.metrics.BinaryAccuracy()
val_acc_loss=tf.keras.metrics.BinaryCrossentropy()


train_loss_results = []
train_accuracy_results = []

validation_loss_results = []
validation_accuracy_results = []

# for loop iterate over epochs
for epoch in range(n_epochs):

    print("Epoch {}/{}".format(epoch, n_epochs))

    # for loop iterate over batches
    for step in range(1, n_steps + 1):
        X_batch, y_batch=random_batch(X_train.values, y_train)

        # gradientTape autodiff
        with tf.GradientTape() as tape:
            y_pred=model(X_batch, training=True)
            loss_values=loss(y_batch, y_pred)
        gradients=tape.gradient(loss_values, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        metricLoss(y_batch, y_pred)
        metricsAcc.update_state(y_batch, y_pred)

        # Loss and accuracy
        train_loss_results.append(loss_values)
        train_accuracy_results.append(metricsAcc.result())

        # Read out training results
        readout = 'Epoch {}, Training loss: {}, Training accuracy: {}'
        print(readout.format(epoch + 1, loss_values,
                              metricsAcc.result() * 100))

        metricsAcc.reset_states

        # Run a validation loop at the end of each epoch

    for valbatch in range(1+ n_steps +1):
        X_batchVal, y_batchVal = random_batch(X_val.values, y_val)

        val_logits = model(X_batchVal)
        # Update val metrics
        val_acc_metric(y_batchVal, val_logits)
        val_acc = val_acc_metric.result()

        val_acc_metric.update_state(y_batchVal, val_logits)

        val_loss=val_acc_loss(y_batchVal, val_logits)

        validation_loss_results.append(val_loss)
        validation_accuracy_results.append(val_acc_metric.result())

        # Read out validation results
        print( 'Validation loss: ' , float(val_loss),'Validation acc: %s' % (float(val_acc * 100),) )

        val_acc_metric.reset_states()

When I run this code, it works fine, and each iteration updates the accuracy and loss. However, after the same 100 epochs the training accuracy is much lower than with model.fit. Here is the output printed for the final epoch (one line per batch):

Epoch 100, Training loss: 0.027735430747270584, Training accuracy: 93.6534423828125
Epoch 100, Training loss: 0.03832387551665306, Training accuracy: 93.67249298095703
Epoch 100, Training loss: 0.035500235855579376, Training accuracy: 93.69097900390625
Validation loss: 0.3204055726528168 Validation acc: 90.36458587646484
Validation loss: 0.32066160440444946 Validation acc: 89.71354675292969
Validation loss: 0.32083287835121155 Validation acc: 90.49479675292969
Validation loss: 0.3209479749202728 Validation acc: 90.10416412353516
Validation loss: 0.32088229060173035 Validation acc: 90.625

As you can see, the training accuracy is ~4-5% lower compared to the model.fit method. The loss records fine, and also, the validation data looks pretty much just like the validation data in the model.fit method.
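One thing worth checking in the loop above: `metricsAcc.reset_states` is written without parentheses, so the method is only referenced and never called. If that is the case, the printed training accuracy would be a running average over every batch since the first epoch, which would lag behind the current accuracy. The sketch below illustrates the effect with a hypothetical pure-Python `RunningAccuracy` class; it is a stand-in for a stateful metric, not the Keras implementation.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a stateful metric (not the Keras class)."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, y_true, y_pred):
        # Accumulate counts across calls, as stateful metrics do.
        self.correct += sum(int(t == p) for t, p in zip(y_true, y_pred))
        self.total += len(y_true)

    def result(self):
        return self.correct / self.total if self.total else 0.0

    def reset_states(self):
        self.correct = 0
        self.total = 0


# Two "batches": an early one that is all wrong, a later one that is all right.
early = ([1, 0, 1, 0], [0, 1, 0, 1])
late = ([1, 0, 1, 0], [1, 0, 1, 0])

metric = RunningAccuracy()
metric.update_state(*early)
metric.reset_states            # no parentheses: nothing is reset
metric.update_state(*late)
acc_without_reset = metric.result()   # the early mistakes still count

metric = RunningAccuracy()
metric.update_state(*early)
metric.reset_states()          # actually clears the accumulated counts
metric.update_state(*late)
acc_with_reset = metric.result()

print(acc_without_reset, acc_with_reset)
```

Calling the reset once per epoch, rather than after every batch, would make the value comparable to a per-epoch figure.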

Additionally, when I plot accuracy and loss for both the model.fit and GradientTape methods, the curves have much the same shape, and both begin to overfit at similar points. But again, there is a large discrepancy in the training accuracy.

I have specified the Adam optimizer and the binary_crossentropy loss in both model.fit and the GradientTape loop. For model.fit, when I specify 'accuracy' or 'acc' for metrics, my understanding is that it uses binary_accuracy to calculate the accuracy. So, as far as I am aware, the parameters are similar enough that the results should be fairly similar.
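For reference, here is a minimal sketch of what a binary accuracy computes on multi-label sigmoid outputs such as the 24-unit layer above: each output is thresholded at 0.5 and the result is averaged over every individual label entry, not over whole samples. The functions and toy data are illustrative assumptions written in plain Python, not the Keras source.

```python
def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of individual label entries predicted correctly."""
    matches = 0
    count = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            matches += int(t == (1 if p > threshold else 0))
            count += 1
    return matches / count


def exact_match_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of samples whose labels are ALL predicted correctly."""
    hits = 0
    for true_row, pred_row in zip(y_true, y_pred):
        rounded = [1 if p > threshold else 0 for p in pred_row]
        hits += int(rounded == list(true_row))
    return hits / len(y_true)


# Toy data: 2 samples with 4 labels each (hypothetical values).
y_true = [[1, 0, 0, 1], [0, 0, 1, 0]]
y_pred = [[0.9, 0.2, 0.6, 0.8], [0.1, 0.4, 0.7, 0.3]]

per_label = binary_accuracy(y_true, y_pred)        # 7 of 8 entries correct
per_sample = exact_match_accuracy(y_true, y_pred)  # 1 of 2 samples fully correct
print(per_label, per_sample)
```

Because the average runs over all entries, a metric of this kind can look high even when few samples are entirely correct, so both code paths need to use the same definition before their numbers are compared.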

Additionally, when I call model.compile after training the model with GradientTape, just to confirm the evaluation, the results are slightly different again and look more like those from model.fit:

**Training**
model.compile(optimizer=optimizer, loss=tf.keras.losses.binary_crossentropy, metrics=['acc'])
print('\n', model.evaluate(X_train, y_train, verbose=1)[1])

32/101 [========>.....................] - ETA: 0s - loss: 0.0336 - acc: 0.9948
101/101 [==============================] - 0s 307us/sample - **loss: 0.0330 - acc: 0.9942**

**Validation**
model.compile(optimizer=optimizer, loss=tf.keras.losses.binary_crossentropy, metrics=['acc'])
print('\n', model.evaluate(X_val, y_val, verbose=1)[1])

18/18 [==============================] - 0s 111us/sample - **loss: 0.3879 - acc: 0.9028**

Now model.evaluate on X_train and y_train shows a loss and accuracy that are very similar to the model.fit results. This is why I suspect a bug. Interestingly, model.evaluate on the validation data looks similar to the GradientTape loop, which leaves me really confused, as I am unsure of the true training accuracy and loss!
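One possible source of the gap between a metric logged during training and model.evaluate afterwards: a value accumulated while the weights are still changing averages over older, worse weights, whereas an evaluation pass uses only the final weights. The toy example below (a one-parameter linear model trained by plain SGD, entirely hypothetical and unrelated to the network above) sketches that effect.

```python
def mse(w, xs, ys):
    """Mean squared error of the model y = w * x over a dataset."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]   # generated by y = 2x
w = 0.0                     # initial weight
lr = 0.02                   # learning rate (arbitrary choice)

batch_losses = []
for x, y in zip(xs, ys):            # one "epoch" with batch size 1
    error = w * x - y
    batch_losses.append(error ** 2) # loss measured BEFORE the update
    w -= lr * 2 * error * x         # gradient step on the squared error

# What a running metric reports: the mean of losses seen during the epoch.
running_mean_loss = sum(batch_losses) / len(batch_losses)

# What an evaluation pass reports: the loss with the final weights only.
final_loss = mse(w, xs, ys)

print(running_mean_loss, final_loss)
```

The longer a metric goes without being reset, the more old batches it averages over and the further it can lag behind an evaluation with the current weights.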

If anyone can help I would really appreciate it. I am happy to provide further code upstream of the model, etc. Again, apologies if this is not a bug, but this looks like incorrect behaviour to me.

@oanush oanush self-assigned this Jan 6, 2020
@oanush oanush added comp:keras Keras related issues comp:ops OPs related issues labels Jan 6, 2020
@HuggingLLM

HuggingLLM commented Jan 7, 2020

Same Problem #35533

@oanush oanush added the type:bug Bug label Jan 7, 2020
@oanush oanush assigned gowthamkpr and unassigned oanush Jan 7, 2020
@amjass12
Author

amjass12 commented Jan 8, 2020

@NLP-ZY is this a known problem? I didn't understand your comment about it being caused by the relu activation? Is this a known bug? If so is there a solution?

@HuggingLLM

@amjass12 There are two problems. The first is caused by the relu activation: when I use model.fit, the results are perfect for acc, loss, val_acc and val_loss, but when I use GradientTape, acc and loss are abnormal, and after about 8 epochs the predictions for every batch are the same. Second, when I remove the relu activation, the predictions for each batch become normal and acc and loss are as expected, but val_acc and val_loss are still bad. Either way, with the same network, model.fit gives much better val_acc and val_loss than GradientTape.

@HuggingLLM

This is a 13-class classification network; my code is in #35533.
model.fit result after 30 epochs:
loss: 0.0050 - acc: 0.9987 - val_loss: 0.0716 - val_acc: 0.9908
GradientTape result after 30 epochs:
loss: 0.0160 - acc: 0.9987 - val_loss: 4.0659 - val_acc: 0.8882

@amjass12
Author

amjass12 commented Jan 12, 2020

Hi @NLP-ZY ,

Thank you for the reply and detailed information. This is very interesting, is this likely a bug?

With regards to removing the relu, can you expand on this please ..

Do you mean you use a different activation function? I would really like to fix this as soon as possible.

@HuggingLLM

Sorry for the slow response, I have been busy recently. Removing relu means using activation=None. I have tried many times, and I think GradientTape may have some bugs: its results are worse than those of model.fit.

@HuggingLLM

HuggingLLM commented Jan 17, 2020

# -*- coding: utf8 -*-

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow import keras as tfk


class Test(tfk.Model):
    def __init__(self):
        super(Test, self).__init__()
        self.embedding_layer = tfk.layers.Embedding(50000, 300)
        self.conv1d_layer = tfk.layers.Conv1D(256, 5)
        self.pool_layer = tfk.layers.MaxPool1D(pool_size=5, strides=2)
        self.dense1_layer = tfk.layers.Dense(128, activation=tfk.activations.relu)
        self.dense2_layer = tfk.layers.Dense(10, activation=tfk.activations.softmax)

    def call(self, inputs, training=None, mask=None):
        hidden = self.embedding_layer(inputs)
        hidden = self.conv1d_layer(hidden)
        hidden = self.pool_layer(hidden)
        hidden = tfk.layers.Flatten()(hidden)
        hidden = self.dense1_layer(hidden)
        y_pred = self.dense2_layer(hidden)
        return y_pred


class Test2(tfk.Model):
    def __init__(self):
        super(Test2, self).__init__()
        self.embedding_layer = tfk.layers.Embedding(50000, 300)
        self.rnn_layer = tfk.layers.LSTM(200)
        self.dense1_layer = tfk.layers.Dense(128, activation=tfk.activations.relu)
        self.dense2_layer = tfk.layers.Dense(10, activation=tfk.activations.softmax)

    def call(self, inputs, training=None, mask=None):
        hidden = self.embedding_layer(inputs)
        hidden = self.rnn_layer(hidden)
        hidden = self.dense1_layer(hidden)
        y_pred = self.dense2_layer(hidden)
        return y_pred


gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
epochs = 30
x = np.random.randint(low=0, high=50000, size=(10000, 128))
y = np.random.randint(low=0, high=10, size=(10000,))
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)
trainset = tf.data.Dataset.zip(
    (tf.data.Dataset.from_tensor_slices(x_train), tf.data.Dataset.from_tensor_slices(y_train))).batch(300)
valset = tf.data.Dataset.zip(
    (tf.data.Dataset.from_tensor_slices(x_val), tf.data.Dataset.from_tensor_slices(y_val))).batch(300)
# model = Test()
model = Test2()
train_acc = tf.metrics.SparseCategoricalAccuracy()
val_acc = tf.metrics.SparseCategoricalAccuracy()
train_loss = tf.metrics.Mean()
val_loss = tf.metrics.Mean()
loss_object = tf.losses.SparseCategoricalCrossentropy()
optimizer = tf.optimizers.Adam()

model.compile(optimizer=optimizer, loss=loss_object, metrics=[train_acc])
model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=300, epochs=epochs)

# model = Test()
model = Test2()


@tf.function
def train_op(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = loss_object(y, y_pred)
        train_loss.update_state(loss)
        train_acc.update_state(y, y_pred)
        tf.print(y_pred)
        grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))


@tf.function
def val_op(x, y):
    y_pred = model(x)
    loss = loss_object(y, y_pred)
    val_loss.update_state(loss)
    val_acc.update_state(y, y_pred)


for epoch in range(epochs):
    print("Epoch {}/{}".format(epoch + 1, epochs))
    bar = tfk.utils.Progbar(len(y_train), unit_name="sample", stateful_metrics={"loss", "acc"})
    log_values = []
    for batch_x, batch_y in trainset:
        train_op(batch_x, batch_y)
        log_values.append(("loss", train_loss.result().numpy()))
        log_values.append(("acc", train_acc.result().numpy()))
        bar.add(len(batch_y), log_values)
    for batch_x, batch_y in valset:
        val_op(batch_x, batch_y)
    print("val_loss -", val_loss.result().numpy(), "val_acc -", val_acc.result().numpy())

@amjass12 This is my test code. With the same network, the result of model.fit is better than that of GradientTape.

@HuggingLLM

@gowthamkpr Can you help us figure this out? Is there some implicit operation in model.fit? Why is there such a large difference between model.fit and GradientTape?
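On the question of implicit operations: as far as I know, model.fit by default reshuffles array inputs every epoch (shuffle=True), visits every sample once per epoch including a smaller final batch, resets its metrics at the start of each epoch, and runs the layers in training mode. A hand-written loop has to do each of these itself. The sketch below covers only the batching part, comparing epoch-style batches with batches drawn with replacement (as a helper built on np.random.randint would do). The function names are hypothetical, and the 101-sample size is taken from the model.evaluate output earlier in the thread.

```python
import random


def epoch_batches(n_samples, batch_size, rng):
    """Yield index batches covering every sample exactly once, shuffled."""
    order = list(range(n_samples))
    rng.shuffle(order)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]   # the last batch may be smaller


def random_batches(n_samples, batch_size, n_steps, rng):
    """Yield index batches drawn independently with replacement."""
    for _ in range(n_steps):
        yield [rng.randrange(n_samples) for _ in range(batch_size)]


rng = random.Random(0)
n_samples, batch_size = 101, 32

seen_epoch = set()
for batch in epoch_batches(n_samples, batch_size, rng):
    seen_epoch.update(batch)

seen_random = set()
n_steps = n_samples // batch_size   # 3 steps, so at most 96 draws
for batch in random_batches(n_samples, batch_size, n_steps, rng):
    seen_random.update(batch)

# Epoch-style batching touches all samples; random draws miss some.
print(len(seen_epoch), len(seen_random))
```

With sampling of this kind, one "epoch" of the custom loop sees fewer distinct samples than one epoch of model.fit, so per-epoch numbers are not directly comparable.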

@amjass12
Author

amjass12 commented Jan 17, 2020

Hi @NLP-ZY ,

No worries! Yes, this is interesting behaviour indeed! I assume specifying activation=None means that the layer uses no activation function and is therefore linear?

I would really like to understand this behaviour and get to the bottom of this as I need gradient tape for a shared dense layer!

@wmmxk

wmmxk commented Jan 26, 2020

I also ran into the same problem. The loss is far larger when I update the weights using tf.GradientTape than when I call model.fit. I created a reproducible example in Colab. Could anyone have a look? It would take 2 minutes to reproduce the problem.

@Costyv95

Costyv95 commented Feb 4, 2020

Do you have any updates on this issue? Like what's causing it? Maybe a quick fix?

@brian36

brian36 commented Feb 7, 2020

I have the exact same problem - consistently much worse results using gradient tape as opposed to model fit, for the exact same network and training. The network I'm using has batch norm, relu, and dropout.
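For networks with batch norm and dropout, one thing to rule out is the training flag: these layers behave differently in training and inference, and in a custom loop the model call needs an explicit training=True during the training step. The thread does not confirm this as the cause here. The sketch below uses a hypothetical plain-Python `dropout` function (inverted dropout, not the Keras layer) to show how the flag changes the output.

```python
import random


def dropout(values, rate, training, rng):
    """Inverted dropout: active only when `training` is True."""
    if not training:
        return list(values)             # inference: values pass through unchanged
    keep = 1.0 - rate
    # Keep each value with probability `keep` and rescale it; zero it otherwise.
    return [v / keep if rng.random() < keep else 0.0 for v in values]


rng = random.Random(0)
activations = [1.0] * 8

train_out = dropout(activations, rate=0.5, training=True, rng=rng)
eval_out = dropout(activations, rate=0.5, training=False, rng=rng)

print(train_out)   # a mix of 0.0 and 2.0
print(eval_out)    # identical to the input
```

If the two code paths being compared run such layers in different modes, their losses and accuracies would not be expected to match.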

@amjass12
Author

amjass12 commented Feb 7, 2020

I am glad this problem is reproducible, maybe someone can look into this?

@brian36

brian36 commented Feb 7, 2020

yes right now because of this I have to give up on tensorflow 2 and revert to 1.x - may consider pytorch as well

@gowthamkpr gowthamkpr added comp:apis Highlevel API related issues and removed comp:ops OPs related issues labels Feb 20, 2020
@gowthamkpr gowthamkpr assigned pavithrasv and unassigned gowthamkpr Feb 20, 2020
@amjass12
Author

Any update on this? It would be nice to make some progress, as I will need to use this soon for some unequal-length inputs and merged layers :D

thanks!

@amalroy2016

Experiencing the same here. Any updates on when to expect a fix???

@bentyeh

bentyeh commented Apr 16, 2020

Please take a look at issue #38596, linked above. I believe this is the underlying issue for why the Keras model.fit() method computes the wrong epoch loss.

@jvishnuvardhan
Contributor

jvishnuvardhan commented Apr 17, 2020

@NLP-ZY I can reproduce your error. Gist is here. The reasons behind the discrepancy are mentioned here.

In short,

  1. Due to numerical instability issues, we need to remove softmax from the final layer of the model.
  2. When you specify loss_object as tf.keras.losses.SparseCategoricalCrossentropy(), under the hood model.fit uses tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True). However, in custom training with the tf.function decorator, we need to specify tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) explicitly.
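To make the numerical-stability point in item 1 concrete, here is a small plain-Python sketch comparing cross-entropy computed directly from logits (via log-sum-exp) with cross-entropy computed from softmax probabilities that are clipped to avoid log(0). The function names and the clipping constant `eps=1e-7` are illustrative assumptions; this is not the TensorFlow implementation.

```python
import math


def softmax(logits):
    """Softmax with the usual max-subtraction for stability."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def xent_from_logits(label, logits):
    """Sparse categorical cross-entropy computed directly from logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def xent_from_probs(label, probs, eps=1e-7):
    """Same loss from probabilities, clipped so that log(0) cannot occur."""
    p = min(max(probs[label], eps), 1.0 - eps)
    return -math.log(p)


# Moderate scores: both routes agree.
mild = [2.0, 1.0, 0.1]
mild_logits_loss = xent_from_logits(0, mild)
mild_probs_loss = xent_from_probs(0, softmax(mild))

# Extreme scores with the true class ranked last: clipping caps the loss.
extreme = [30.0, 0.0, -30.0]
extreme_logits_loss = xent_from_logits(2, extreme)        # about 60
extreme_probs_loss = xent_from_probs(2, softmax(extreme)) # capped at -log(eps)

print(mild_logits_loss, mild_probs_loss)
print(extreme_logits_loss, extreme_probs_loss)
```

Note that from_logits=True is only appropriate when the final layer outputs raw scores; if the layer still applies softmax, the loss would effectively apply softmax twice.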

I made those changes and ran your code for 25 epochs. I get 93.24% training accuracy. Please take a look at the gist. Thanks.

Please verify once and close the issue if this was resolved for you. Thanks!

@jvishnuvardhan
Contributor

@amjass12 Can you please follow the steps mentioned above and verify whether it resolved the issue for you. If your issue was not resolved, can you please post a complete standalone code to reproduce your issue.

If the issue was resolved, then please close the issue. Thanks!

@HuggingLLM

@jvishnuvardhan Thanks for your reply. I have tried replacing tf.keras.losses.SparseCategoricalCrossentropy() with tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True).
The GradientTape loss still decreases more slowly than with model.fit, so this doesn't solve my problem.
I want to figure out how to get the same loss and acc with GradientTape as with model.fit in the same number of epochs, because some complex models can't use model.fit. At epoch 5 I get loss: 0.0087 - acc: 0.9998 with model.fit, but loss: 0.9447 - acc: 0.6677 with GradientTape. Maybe I should train for more epochs when I use GradientTape?

@HuggingLLM

@jvishnuvardhan I tried setting epoch=100 for GradientTape and epoch=5 for model.fit. I get loss: 0.0030 - acc: 1.0000 at epoch 5 (model.fit), and loss: 1.4832 - acc: 0.9774 at epoch 100 (GradientTape). You can see my results in the gist. Please help us figure out why GradientTape doesn't work as well as model.fit.
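A loss stuck near 1.48 alongside roughly 98% accuracy is the pattern one would expect if a from_logits=True loss were fed the output of a layer that still applies softmax: the probabilities lie in [0, 1], so a second softmax can never become confident. Whether that is what happened in the gist cannot be confirmed from this thread, so treat the following as a hypothesis. The sketch computes the lowest loss reachable in that situation for the 10 classes used in the test code.

```python
import math


def softmax(values):
    """Softmax with max-subtraction for stability."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


n_classes = 10

# Best case for the model: a perfectly confident softmax output for class 0.
perfect_probs = [1.0] + [0.0] * (n_classes - 1)

# Treating these probabilities as logits applies softmax a second time.
double_softmax = softmax(perfect_probs)

# Cross-entropy for the true class 0 after the second softmax.
floor_loss = -math.log(double_softmax[0])   # equals log(1 + 9 / e)

print(floor_loss)
```

Under that assumption the loss cannot drop below about 1.46 however well the model classifies, which would be consistent with high accuracy combined with a seemingly large loss.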

@amalroy2016

Does using tf.gradients instead of GradientTape make a difference in your cases?

@HuggingLLM

@amalroy2016 Thanks, I tried it, but it doesn't work.

@amjass12
Author

@jvishnuvardhan

Many thanks for the message I will do this and report back asap!

@jvishnuvardhan
Contributor

@amjass12 Did you have time to look at my gist? Please also check the solution (using reset_states) that worked for @NLP-ZY.

Please verify once and close the issue. Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting response Status - Awaiting response from author label May 21, 2020
@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label May 28, 2020
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.
