When I run the second code block (the custom training loop that calls wrapper.train_step), I get the following exception:
AttributeError: 'tuple' object has no attribute 'rank'
Environment:
tensorflow=2.12.0
capsa=0.1.5
Starting epoch 1/6
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-7-830de383472c> in <cell line: 5>()
8 # Get a batch of training data and compute the training step
9 for step, data in enumerate(train_loader):
---> 10 metrics = wrapper.train_step(data)
11 if step % 100 == 0:
12 print(step)
/usr/local/lib/python3.10/dist-packages/capsa/bias/histogramvae.py in train_step(self, data, prefix)
225 with tf.GradientTape() as t:
226 metric_loss, y_hat,bias = self.loss_fn(x, y)
--> 227 compiled_loss = self.compiled_loss(
228 y, y_hat, regularization_losses=self.losses
229 )
/usr/local/lib/python3.10/dist-packages/keras/engine/compile_utils.py in __call__(self, y_true, y_pred, sample_weight, regularization_losses)
261 continue
262
--> 263 y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
264 sw = losses_utils.apply_mask(y_p, sw, losses_utils.get_mask(y_p))
265 loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/usr/local/lib/python3.10/dist-packages/keras/engine/compile_utils.py in match_dtype_and_rank(y_t, y_p, sw)
829 def match_dtype_and_rank(y_t, y_p, sw):
830 """Match dtype and rank of predictions."""
--> 831 if y_t.shape.rank == 1 and y_p.shape.rank == 2:
832 y_t = tf.expand_dims(y_t, axis=-1)
833 if sw is not None:
AttributeError: 'tuple' object has no attribute 'rank'
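The failing check reads y_t.shape.rank, but here y_t.shape is a plain tuple. That suggests the labels y passed to self.compiled_loss are not a tf.Tensor (whose .shape is a tf.TensorShape with a .rank attribute) but something like a NumPy array, whose .shape is a tuple. So the problem may lie either in how the data loader yields batches or in capsa's train_step, which hands y to Keras without converting it.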
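A quick way to confirm this diagnosis is to look at what .shape returns for the labels. The sketch below uses a NumPy array as a stand-in for one batch of labels; that the loader yields NumPy arrays is an assumption, since the loader code is not shown here. A NumPy .shape is a plain tuple, so the .rank lookup in match_dtype_and_rank fails with the same message as in the traceback.

```python
import numpy as np

# Stand-in for one batch of labels (assumes the loader yields NumPy arrays).
y = np.zeros((32,), dtype=np.float32)

# NumPy exposes .shape as a plain Python tuple ...
print(type(y.shape))  # <class 'tuple'>

# ... so the attribute Keras reads in match_dtype_and_rank does not exist.
try:
    y.shape.rank
    message = None
except AttributeError as err:
    message = str(err)
print(message)  # 'tuple' object has no attribute 'rank'

# NumPy reports the rank as .ndim instead.
print(y.ndim)  # 1
```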
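I worked around this by modifying histogramvae.py in the capsa library (/usr/local/lib/python3.10/dist-packages/capsa/bias/histogramvae.py, the second file in the traceback above) so that y is converted to a float32 tensor before the losses are computed: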
# @tf.function
def train_step(self, data, prefix=None):
    # docstring omitted
    x, y = data
    y = tf.convert_to_tensor(y, dtype=tf.float32)  # added line (line 224 of histogramvae.py)
    with tf.GradientTape() as t:
        metric_loss, y_hat, bias = self.loss_fn(x, y)
        compiled_loss = self.compiled_loss(
            y, y_hat, regularization_losses=self.losses
        )
        loss = metric_loss + compiled_loss
    .....
    return keras_metrics
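An alternative that avoids editing the installed package is to convert each batch in the training loop before passing it to train_step. This is only a sketch, assuming each batch is an (x, y) pair (as the x, y = data unpacking in train_step implies); the helper name to_tensor_batch is hypothetical, and wrapper and train_loader are the objects from the notebook cell. Like the patch above, it converts only the labels to float32 and leaves x untouched.

```python
import numpy as np
import tensorflow as tf


def to_tensor_batch(data, dtype=tf.float32):
    """Hypothetical helper: return (x, y) with y converted to a tf.Tensor.

    Keras' compiled loss reads y.shape.rank, which exists on tf.TensorShape
    but not on the tuple returned by a NumPy array's .shape.
    """
    x, y = data
    y = tf.convert_to_tensor(y, dtype=dtype)
    return x, y


# Self-contained check with a fake NumPy batch (stand-in for train_loader output).
fake_batch = (
    np.random.rand(4, 3).astype(np.float32),
    np.random.rand(4).astype(np.float32),
)
x, y = to_tensor_batch(fake_batch)
print(y.shape.rank)  # 1

# In the notebook loop this would become (wrapper/train_loader from the report):
# for step, data in enumerate(train_loader):
#     metrics = wrapper.train_step(to_tensor_batch(data))
```

Converting on the caller side keeps the installed capsa package unmodified, at the cost of having to remember the conversion in every custom loop.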