## Converting to Keras (TensorFlow)

In [None]:
%pip install tensorflow transformers nobuco

# PyTorch is also required; install it by following https://pytorch.org/get-started/locally/

In [None]:
import nobuco
from nobuco import ChannelOrder, ChannelOrderingStrategy
import torch
import torch.nn.functional as F
import tensorflow as tf
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# Shared tokenizer for the whole notebook: the cells below call it to turn text prompts
# into input ids, as PyTorch tensors for tracing and as TensorFlow tensors for inference.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

In [2]:
@nobuco.converter(F.softplus, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def softplus(input: torch.Tensor, beta: float = 1.0, threshold: float = 20.0):
    """nobuco converter that maps ``F.softplus`` onto TensorFlow ops.

    The signature mirrors ``F.softplus(input, beta=1.0, threshold=20.0)`` so that
    traced calls which pass ``beta`` and/or ``threshold`` are accepted instead of
    failing with a ``TypeError``.

    Args:
        input: Tensor the traced ``F.softplus`` call was applied to.
        beta: Sharpness factor; the result is ``softplus(beta * x) / beta``.
        threshold: Where ``beta * x`` exceeds this value the input is passed
            through unchanged, as PyTorch does.

    Returns:
        A callable taking the same arguments that computes the activation with
        TensorFlow.
    """
    def func(input, beta=1.0, threshold=20.0):
        if beta == 1 and threshold == 20:
            # Default arguments: emit the single softplus op exactly as before.
            # Past the threshold softplus(x) and x differ by < exp(-20), which is
            # below float32 resolution, so the linear branch is not needed here.
            return tf.keras.activations.softplus(input)
        scaled = input * beta
        smooth = tf.keras.activations.softplus(scaled) / beta
        # Reproduce PyTorch's switch to the identity above the threshold.
        return tf.where(scaled > threshold, input, smooth)
    return func

In [3]:
@nobuco.converter(torch.einsum, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
def converter_einsum(*args):
    """Converter that substitutes ``tf.einsum`` for a traced ``torch.einsum`` call.

    The arguments of the traced call are accepted but not inspected. The returned
    callable expects the equation string first and the operands after it, and
    forwards them unchanged to ``tf.einsum``.
    """
    def einsum_tf(*call_args):
        # call_args[0] is the equation; everything after it is an operand.
        return tf.einsum(call_args[0], *call_args[1:])
    return einsum_tf

In [4]:
# Mamba(args) loads no checkpoint in this cell, so its weights presumably come from
# torch's random initialisers. Seed the RNG first so the converted model, and every
# logit printed further down, is the same on each Restart & Run All.
torch.manual_seed(0)

# Deliberately tiny configuration (one layer, d_model=5) to keep tracing and export fast.
args = ModelArgs(
    d_model=5,
    n_layer=1,
    vocab_size=50277
)
model = Mamba(args)
model.eval()  # trace the model in inference mode
export_name = "mamba_minimal_1_layer"  # base name reused by the save / reload cells

# A one-word prompt is enough: the sequence axis is declared dynamic below.
dummy_input = "Test"
input_ids = tokenizer(dummy_input, return_tensors='pt').input_ids

# NOTE(review): the trace printed by this cell reports the PyTorch output as
# <1,1,50280>, while the SavedModel cell below prints shape (1, 50280, 2) for its prompt.
# It looks like outputs_channel_order=TENSORFLOW moves axis 1 to the end, so the logits
# come back as (batch, vocab, seq). ChannelOrder.PYTORCH would presumably keep the
# PyTorch layout; confirm before relying on the axis order downstream.
keras_model = nobuco.pytorch_to_keras(
    model,
    args=[input_ids], kwargs=None,
    input_shapes={input_ids: (1, None)}, # Annotate dynamic axes with None
    inputs_channel_order=ChannelOrder.TENSORFLOW,
    outputs_channel_order=ChannelOrder.TENSORFLOW,
    constants_to_variables=False,
    trace_shape=True,
    save_trace_html=True
)

2024-02-19 19:47:13.408601: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:09:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-02-19 19:47:13.409221: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2024-02-19 19:47:13.481166: I external/local_xla/xla/service/service.cc:168] XLA service 0xea8a220 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2024-02-19 19:47:13.481192: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1708368433.488287   28138 device_compiler.h:18

Legend:
    [32mGreen[0m — conversion successful
    [33mYellow[0m — conversion imprecise
    [31mRed[0m — conversion failed
    [31m[7mRed[0m — no converter found
    [0m[1mBold[0m — conversion applied directly
    * — subgraph reused
    [7mTensor[0m — this output is not dependent on any of subgraph's input tensors
    [4mTensor[0m — this input is a parameter / constant
    [90mTensor[0m — this tensor is useless

[32mMamba[model][0m(int64_0<1,1>[0m) -> float32_87<1,1,50280>[0m
[32m │ [0m [32mEmbedding[torch.nn.modules.sparse][0m(int64_0<1,1>[0m) -> float32_2<1,1,5>[0m
[32m │ [0m [32m └·[0m [32m[1membedding[torch.nn.functional][0m(int64_0<1,1>[0m, [4mfloat32_1<50280,5>[0m, None, None, 2.0, False, False) -> float32_2<1,1,5>[0m
[32m │ [0m [32mResidualBlock[model][0m(float32_2<1,1,5>[0m) -> float32_79<1,1,5>[0m
[32m │ [0m [32m │ [0m [32mRMSNorm[model][0m(float32_2<1,1,5>[0m) -> float32_9<1,1,5>[0m
[32m │ [0m [32m │ [0m [32m │ [0m

In [None]:
# Write the converted model to disk in every format exercised by the reload cells below:
# the two Keras single-file formats first, then a SavedModel under the bare export name.
for keras_suffix in (".keras", ".h5"):
    tf.keras.models.save_model(keras_model, export_name + keras_suffix)
tf.saved_model.save(keras_model, export_name)

In [7]:
# Inference test for the nobuco-converted model that is still in memory
# prompt
dummy_prompt_keras = "Harry Potter test"
input_ids_keras = tokenizer(dummy_prompt_keras, return_tensors='tf').input_ids

# The conversion trace shows the model was traced with int64 ids, so cast explicitly,
# exactly as the reload cells below do, instead of relying on implicit conversion.
input_ids_keras = tf.cast(input_ids_keras, tf.int64)
# inference
out = keras_model.predict(input_ids_keras)
# output
print(out)

[[[-3.0078745   3.9671001  -2.305275  ]
  [-1.5027751  -1.864156    2.8154929 ]
  [-2.2186959   5.2344885  -1.0442678 ]
  ...
  [-4.8536053   5.0152903  -3.8399842 ]
  [-0.11023474  0.26254967  2.5834856 ]
  [ 1.321992   -1.8405054  -1.6095277 ]]]


## SavedModel 

In [7]:
# Round-trip check for the SavedModel export.
# Tokenise the prompt and cast the ids to int64 in one step.
dummy_prompt_keras = "Harry Potter"
input_ids_keras = tf.cast(
    tokenizer(dummy_prompt_keras, return_tensors='tf').input_ids,
    tf.int64,
)

# The SavedModel was written under the bare export name, so the path has no extension.
export_path = "mamba_minimal_1_layer" # No '.h5' in the path!
keras_model_restored = tf.saved_model.load(export_path)

# The restored object is called directly on the ids (no .predict here).
out = keras_model_restored(input_ids_keras)
print(out)

tf.Tensor(
[[[ 2.1814482   5.0091105 ]
  [-0.9071193  -2.4775052 ]
  [-2.2117083  -0.39902794]
  ...
  [-0.49433738  1.2557342 ]
  [-2.74676    -0.9496337 ]
  [ 1.0100164  -4.3799496 ]]], shape=(1, 50280, 2), dtype=float32)


## .keras

In [10]:
# Round-trip check for the .keras export.
# NOTE(review): the output recorded for this cell is an AttributeError raised while
# calling layer "lambda_12", so this reload path does not currently work.
dummy_prompt_keras = "Harry Potter"
input_ids_keras = tf.cast(
    tokenizer(dummy_prompt_keras, return_tensors='tf').input_ids,
    tf.int64,
)

# safe_mode=False: the converted graph contains Lambda layers (see the recorded error),
# which presumably is why safe loading is switched off here.
export_name = "mamba_minimal_1_layer"
keras_model = tf.keras.saving.load_model(export_name + ".keras", safe_mode=False)

out = keras_model.predict(input_ids_keras)
print(out)

AttributeError: Exception encountered when calling layer "lambda_12" (type Lambda).

'list' object has no attribute 'shape'

Call arguments received by layer "lambda_12" (type Lambda):
  • inputs=['tf.Tensor(shape=(1, None, 10), dtype=float32)', [['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), 
dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 
'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 
'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)']]]
  • mask=None
  • training=None

## .h5

In [11]:
# Round-trip check for the HDF5 (.h5) export.
# NOTE(review): the output recorded for this cell is an AttributeError raised while
# calling layer "lambda_12", so this reload path does not currently work.
dummy_prompt_keras = "Harry Potter"
input_ids_keras = tf.cast(
    tokenizer(dummy_prompt_keras, return_tensors='tf').input_ids,
    tf.int64,
)

# Rebuild the model from the HDF5 file written by the export cell.
export_name = "mamba_minimal_1_layer"
keras_model = tf.keras.saving.load_model(export_name + ".h5")

out = keras_model.predict(input_ids_keras)
print(out)

AttributeError: Exception encountered when calling layer "lambda_12" (type Lambda).

'list' object has no attribute 'shape'

Call arguments received by layer "lambda_12" (type Lambda):
  • inputs=['tf.Tensor(shape=(1, None, 10), dtype=float32)', [['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), 
dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 
'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)'], ['tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 
'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)', 'tf.Tensor(shape=(), dtype=float32)']]]
  • mask=None
  • training=None