In [2]:
import tensorflow as tf

In [3]:
class Test:
    def __init__(self, a: str, b: tf.Tensor) -> None:
        self.a = a
        self.b = b

@tf.function
def test(x: tf.Tensor, test_obj: Test):
    print("Tracing")
    return x + test_obj.b

In [4]:
test_obj = Test('hello', tf.constant(3))

test(tf.constant(1), test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [5]:
test(tf.constant(1), test_obj)

<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [6]:
# Expected 5 but got 4 instead
test_obj.b = tf.constant(4)
test(tf.constant(1), test_obj)


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [7]:
test_obj = Test('hello', tf.constant(4))
test(tf.constant(1), test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=5>

In [8]:
class Test2(tf.experimental.ExtensionType):
    a: tf.Tensor
    b: tf.Tensor
    

@tf.function
def test2(x: tf.Tensor, test_obj: Test2):
    print("Tracing")
    return x + test_obj.b

In [9]:
test_obj = Test2('hello', tf.constant(3))
test2(tf.constant(1), test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [10]:
test_obj = Test2('hello', tf.constant(5))
test2(tf.constant(1), test_obj)

<tf.Tensor: shape=(), dtype=int32, numpy=6>

In [11]:
test_obj = Test2('goodbye', tf.constant(3))
test2(tf.constant(1), test_obj)

<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [12]:
class Test3(tf.experimental.ExtensionType):
    a: str
    b: tf.Tensor
    

@tf.function
def test3(x: tf.Tensor, test_obj: Test3):
    print("Tracing")
    return x + test_obj.b

In [13]:
test_obj = Test3('hello', tf.constant(3))
test3(tf.constant(1), test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [14]:
test_obj = Test3('hello', tf.constant(5))
test3(tf.constant(1), test_obj)

<tf.Tensor: shape=(), dtype=int32, numpy=6>

In [15]:
test_obj = Test3('goodbye', tf.constant(3))
test3(tf.constant(1), test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [16]:
from dataclasses import dataclass

@dataclass
class Test4(tf.experimental.ExtensionType):
    a: str
    b: tf.Tensor
    

@tf.function
def test4(x: tf.Tensor, test_obj: Test4):
    print("Tracing")
    return x + test_obj.b

In [17]:
from dataclasses import asdict


test_obj = Test4('hello', tf.constant(3))
print(asdict(test_obj))
test4(tf.constant(1), test_obj)

{'a': 'hello', 'b': <tf.Tensor: shape=(), dtype=int32, numpy=3>}
Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [19]:
test_obj = Test4('hello', tf.constant(5))
test4(tf.constant(1), test_obj)

<tf.Tensor: shape=(), dtype=int32, numpy=6>

In [20]:
test_obj = Test4('goodbye', tf.constant(3))
test4(tf.constant(1), test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [21]:
from dataclasses import asdict, is_dataclass

asdict(test_obj)

{'a': 'goodbye', 'b': <tf.Tensor: shape=(), dtype=int32, numpy=3>}

In [1]:
@dataclass
class Test4(tf.experimental.ExtensionType):
    a: str
    b: tf.Tensor
    c: tf.Tensor
    
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.c = f"{a}_1"


NameError: name 'dataclass' is not defined

In [27]:
Test4('goodbye', tf.constant(3))

Test4(a='goodbye', b=<tf.Tensor: shape=(), dtype=int32, numpy=3>, c=<tf.Tensor: shape=(), dtype=string, numpy=b'goodbye_1'>)

In [31]:
class Hello:
    
    @staticmethod
    @tf.function
    def apply(x, test_obj):
        print("Tracing")
        return x + test_obj.b

    def run(self, test_obj):
        return self.apply(tf.constant(1), test_obj)

In [32]:
test_obj = Test("hello", tf.constant(3))
hello = Hello()
hello.run(test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [33]:
hello.run(test_obj)

<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [34]:
test_obj = Test("hello", tf.Variable(3))
hello = Hello()
hello.run(test_obj)

Tracing


<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [35]:
test_obj.b.assign(4)

<tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=4>

In [36]:
hello.run(test_obj)

<tf.Tensor: shape=(), dtype=int32, numpy=5>