Skip to content

Commit 7acfb87

Browse files
zafaralitensorflower-gardener
authored andcommitted
Add copy and deepcopy functionality for resource_variable_ops.ResourceVariable.
PiperOrigin-RevId: 209823375
1 parent 401cb1a commit 7acfb87

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

tensorflow/python/keras/models_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,36 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import copy
2122
import os
2223

2324
import numpy as np
2425

2526
from tensorflow.python import keras
27+
from tensorflow.python.eager import context
2628
from tensorflow.python.framework import test_util
2729
from tensorflow.python.keras import metrics
2830
from tensorflow.python.keras import models
31+
from tensorflow.python.ops import random_ops
32+
from tensorflow.python.ops import resource_variable_ops
2933
from tensorflow.python.platform import test
3034
from tensorflow.python.training import adam
3135

3236

37+
class TestModel(keras.Model):
38+
"""A model subclass."""
39+
40+
def __init__(self, n_outputs=4, trainable=True):
41+
"""A test class with one dense layer and number of outputs as a variable."""
42+
super(TestModel, self).__init__()
43+
self.layer1 = keras.layers.Dense(n_outputs)
44+
self.n_outputs = resource_variable_ops.ResourceVariable(
45+
n_outputs, trainable=trainable)
46+
47+
def call(self, x):
48+
return self.layer1(x)
49+
50+
3351
class TestModelCloning(test.TestCase):
3452

3553
def test_clone_sequential_model(self):
@@ -187,6 +205,36 @@ def test_model_backend_float64_use_cases(self):
187205
keras.backend.set_floatx(floatx)
188206

189207

208+
class TestModelDeepCopy(test.TestCase):
209+
210+
def test_deep_copy_eager_mode_trainable(self):
211+
with context.eager_mode():
212+
x = random_ops.random_normal((32, 4))
213+
model = TestModel(trainable=True)
214+
model(x) # Initialize Variables.
215+
model_copy = copy.deepcopy(model)
216+
self.assertEqual(len(model_copy.trainable_variables), 3)
217+
model_copy.n_outputs.assign(1200)
218+
self.assertFalse(
219+
np.allclose(model_copy.n_outputs.numpy(),
220+
model.n_outputs.numpy()))
221+
222+
def test_deep_copy_eager_mode_not_trainable(self):
223+
with context.eager_mode():
224+
x = random_ops.random_normal((32, 4))
225+
model = TestModel(trainable=False)
226+
model(x)
227+
model_copy = copy.deepcopy(model)
228+
self.assertEqual(len(model_copy.trainable_variables), 2)
229+
230+
weights = model_copy.get_weights()
231+
weights = [w * 4 for w in weights]
232+
model_copy.set_weights(weights)
233+
self.assertFalse(
234+
np.allclose(model.get_weights()[0],
235+
model_copy.get_weights()[0]))
236+
237+
190238
class TestCloneAndBuildModel(test.TestCase):
191239

192240
def test_clone_and_build_non_compiled_model(self):

tensorflow/python/kernel_tests/resource_variable_ops_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import copy
2021
import gc
2122

2223
import numpy as np
@@ -106,6 +107,27 @@ def testEagerBool(self):
106107
v = resource_variable_ops.ResourceVariable(False, name="bool_test")
107108
self.assertAllEqual(bool(v), False)
108109

110+
def testEagerDeepCopy(self):
111+
with context.eager_mode():
112+
init_value = np.ones((4, 4, 4))
113+
variable = resource_variable_ops.ResourceVariable(init_value,
114+
name="init")
115+
116+
copied_variable = copy.deepcopy(variable)
117+
copied_variable.assign(4 * np.ones((4, 4, 4)))
118+
119+
# Copying the variable should create a new underlying tensor with distinct
120+
# values.
121+
self.assertFalse(np.allclose(variable.numpy(), copied_variable.numpy()))
122+
123+
def testGraphDeepCopy(self):
124+
with self.test_session():
125+
init_value = np.ones((4, 4, 4))
126+
variable = resource_variable_ops.ResourceVariable(init_value,
127+
name="init")
128+
with self.assertRaises(NotImplementedError):
129+
copy.deepcopy(variable)
130+
109131
@test_util.run_in_graph_and_eager_modes
110132
def testStridedSliceAssign(self):
111133
v = resource_variable_ops.ResourceVariable([1.0, 2.0])

tensorflow/python/ops/resource_variable_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,22 @@ def __nonzero__(self):
586586
def __bool__(self):
587587
return bool(self.read_value())
588588

589+
def __copy__(self):
590+
return self
591+
592+
def __deepcopy__(self, memo):
593+
if not context.executing_eagerly():
594+
raise NotImplementedError(
595+
"__deepcopy__() is only available when eager execution is enabled.")
596+
copied_variable = ResourceVariable(
597+
initial_value=self.read_value(),
598+
trainable=self._trainable,
599+
constraint=self._constraint,
600+
dtype=self._dtype,
601+
name=self._shared_name + "_copy")
602+
memo[self._unique_id] = copied_variable
603+
return copied_variable
604+
589605
@property
590606
def dtype(self):
591607
"""The dtype of this variable."""

0 commit comments

Comments
 (0)