
[Fix] Avoid double lookup of tables when using ShadowVariable when the Optimizer is updating gradients #262

Merged: 3 commits, merged on Aug 22, 2022

Conversation

@MoFHeka (Contributor) commented Jul 7, 2022

Description

Before this fix, in a Keras model a ShadowVariable would call read_value with do_prefetch twice: once in embedding_lookup during the forward pass and again in apply_grad_to_update_var during the backward pass. With this change, if var in apply_grad_to_update_var is a ShadowVariable, its value is read directly without prefetching, because the prefetching read_value was already done in embedding_lookup.
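
A minimal sketch of the resulting control flow inside apply_grad_to_update_var (simplified; the exact hunk appears in the review diff further down):

# Inside apply_grad_to_update_var(var, grad) in DynamicEmbeddingOptimizer (sketch;
# `ops` is tensorflow.python.framework.ops and `de` is
# tensorflow_recommenders_addons.dynamic_embedding, as in the diff below).
with ops.control_dependencies([grad]):
  if isinstance(var, de.shadow_ops.ShadowVariable):
    # The ShadowVariable already prefetched its values in embedding_lookup,
    # so read the cached value without touching the table again.
    v0 = var.read_value(do_prefetch=False)
  else:
    v0 = var.read_value(do_prefetch=not var.params.bp_v2)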

Also fix passing the init_size parameter when creating slot variables for de.Variable.
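
For context, a usage sketch (the init_size value below is only illustrative): init_size sets the initial capacity of the de.Variable's table, and with this fix the slot variables the optimizer creates for it are built with the same init_size.

embed = tfra.dynamic_embedding.get_variable(
    name="embed_layer",
    dim=5,
    init_size=131072,  # illustrative capacity; after this fix, slot tables inherit it
)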

Also add compatibility with CUDA 11.6.

Also modify the _convert_anything_to_init function:
1. Make the output shape equal to dim when raw_init is a TF Initializer.
2. Make the input dim a constant op when using the reshape op, to prevent the following error:
*** tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 1 values, but the requested shape has 43
Encountered when executing an operation using EagerExecutor. This error cancels all future operations and poisons their output tensors.
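
An illustration of the two points (hypothetical helper name, not the TFRA source): call the initializer with shape [dim] so it yields dim values, and build the reshape target from a constant op so the requested shape stays static when traced by tf.function.

import tensorflow as tf

def init_values_from_initializer(raw_init, dim):
  # Point 1: ask the TF Initializer for exactly `dim` values.
  values = raw_init(shape=[dim], dtype=tf.float32)
  # Point 2: use a constant op for the target shape of the reshape.
  return tf.reshape(values, tf.constant([dim], dtype=tf.int64))

# Example: 43 values from a RandomNormal initializer, as in the error above.
vals = init_values_from_initializer(tf.keras.initializers.RandomNormal(0, 0.1), 43)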

Also restrict Bazel build RAM usage to stay within the GitHub CI memory limit.
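
The exact limits used in CI are not shown in this thread; a typical way to cap Bazel's memory use looks like the following (illustrative values only):

# .bazelrc (illustrative; not necessarily the values used in this PR)
build --local_ram_resources=HOST_RAM*.5
build --jobs=4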

Type of change

  • Bug fix
  • New Tutorial
  • Updated or additional documentation
  • Additional Testing
  • New Feature

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running yapf
    • By running clang-format
  • This PR addresses an already submitted issue for TensorFlow Recommenders-Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works

How Has This Been Tested?

A test Python script:

import tensorflow as tf
import numpy as np
import tensorflow_recommenders_addons as tfra
from tensorflow.keras import models, losses, layers, metrics, initializers

inp = {
    "f1": np.array([1,2,1,2,1,3,2,3]),
    # "f2": np.array([
    #     [1,2,3],
    #     [3,4,5],
    #     [1,2,3],
    #     [3,4,5],
    #     [3,2,5],
    #     [3,1,6],
    #     [2,1,2],
    #     [3,6,4]
    # ])
    "f2": np.array([
        [1,2,3],
        [3,4,5],
        [1,2,3],
        [3,4,5],
        [1,2,3],
        [3,4,5],
        [1,2,3],
        [3,4,5]
    ])
}

label = np.array([1,1,1,1,1,1,1,1])
merge_ds = tf.data.Dataset.from_tensor_slices((inp, label))

class TestModelTwoFeatures(tf.keras.Model):
    def __init__(self):
        super(TestModelTwoFeatures, self).__init__()
        self.embedding_size = 5
        self.embed_layer = tfra.dynamic_embedding.get_variable(
            name="embed_layer",
            dim=self.embedding_size,
            devices=["CPU:0"],
            # initializer=initializers.RandomNormal(0, 0.1)
            initializer=1
        )
        self.dense = layers.Dense(1, activation="sigmoid")
    
    def call(self, batch):
        f1, f2 = batch["f1"], batch["f2"]
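        # Two lookups against the same dynamic embedding table; return_trainable=True
        # also returns a trainable wrapper for each lookup so gradients can be applied to it.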
        f1_weights, f1_tw = tfra.dynamic_embedding.embedding_lookup(
            params=self.embed_layer,
            ids=f1,
            name="embed_weights_f1",
            return_trainable=True
        )
       
        f2_weights, f2_tw = tfra.dynamic_embedding.embedding_lookup(
            params=self.embed_layer,
            ids=f2,
            name="embed_weights_f2",
            return_trainable=True
        )
        
        f2_weights_merged = layers.GlobalAveragePooling1D()(f2_weights)
        
        final_weights = tf.concat([f1_weights, f2_weights_merged], axis=-1)
        output = self.dense(final_weights)
        return f1_weights, f2_weights, output, [f1_tw, f2_tw]

model = TestModelTwoFeatures()

optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(optimizer)

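# Training loop: gradients with respect to the returned trainables (tws) are applied
# through DynamicEmbeddingOptimizer, which updates the dynamic embedding table.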
for i, (batch, label) in enumerate(merge_ds.batch(2)):
    print(f"batch No. = {i}")
    print(f"batch data = {batch}")
    print(f"batch_label = {label}")
    with tf.GradientTape() as tape:
        f1_weights, f2_weights, output, tws = model(batch)
        # The dense layer already applies a sigmoid, so the loss takes probabilities.
        loss = losses.BinaryCrossentropy(from_logits=False)(label, output)
    grads = tape.gradient(loss, model.trainable_variables + tws)
    optimizer.apply_gradients(zip(grads, model.trainable_variables + tws))
    print(f"f1_weights = {f1_weights}")
    print(f"f2_weights = {f2_weights}")
    print()

@MoFHeka requested a review from rhdong as a code owner on July 7, 2022 11:28
@MoFHeka requested a review from Lifann on July 7, 2022 11:30
@MoFHeka changed the title from "[fix] Avoid double lookup of tables when using ShadowVariable when the Optimizer is updating gradients" to "[Fix] Avoid double lookup of tables when using ShadowVariable when the Optimizer is updating gradients" on Jul 7, 2022
@@ -101,7 +101,10 @@ def apply_grad_to_update_var(var, grad):
       var._track_optimizer_slots(_slots)

       with ops.control_dependencies([grad]):
-        v0 = var.read_value(do_prefetch=not var.params.bp_v2)
+        if isinstance(var, de.shadow_ops.ShadowVariable):
+          v0 = var.read_value(do_prefetch=False)
@Lifann (Member) commented on the diff above:

Will it be OK if the lookup happens multiple times in one function call, from different inputs?

@rhdong (Member) commented Jul 11, 2022:

Hi @Lifann, I remember @MoFHeka said this was not a real issue; can it be closed?

@Lifann (Member), replying to the comment above:

Pass it then.

@MoFHeka force-pushed the redis-dev branch 2 times, most recently from 6c00505 to 7420ea5 on July 12, 2022 10:15
@@ -51,6 +51,14 @@ _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
 _PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

 _DEFAULT_CUDA_COMPUTE_CAPABILITIES = {
+    "11.6": [
@Lifann (Member) commented Jul 12, 2022:

Maybe update the CUDA compute capability in another commit.

Lifann previously approved these changes on Jul 15, 2022.

@Lifann (Member) left a comment:

LGTM

Commits:

[fix] Avoid double lookup of tables when using ShadowVariable when the Optimizer is updating gradients.

[fix] Pass parameter init_size when creating slot variables for de.Variable.
1. Make the output shape equal to dim when raw_init is a TF Initializer.
2. Make the input dim a constant op when using the reshape op, to prevent a fault under tf.function.
@Lifann (Member) left a comment:

Approve again.

@rhdong merged commit 41b3daf into tensorflow:master on Aug 22, 2022