This is a companion notebook for the book [Deep Learning with Python, Third Edition](https://www.manning.com/books/deep-learning-with-python-third-edition). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.

**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**

The book's contents are available online at [deeplearningwithpython.io](https://deeplearningwithpython.io).

In [None]:
!pip install keras keras-hub --upgrade -q

In [None]:
import os
# Select the Keras backend through the KERAS_BACKEND environment variable.
# NOTE(review): presumably this has to run before `keras` is first imported — confirm.
os.environ["KERAS_BACKEND"] = "jax"

In [None]:
# @title
import os
from IPython.core.magic import register_cell_magic

@register_cell_magic
def backend(line, cell):
    """Cell magic `%%backend NAME`: run the cell only under the named Keras backend.

    Args:
        line: Text following `%%backend` on the magic line; its last
            whitespace-separated token is taken as the required backend.
        cell: Source of the cell body, executed with `run_cell` on a match.

    The cell is executed only when the `KERAS_BACKEND` environment variable
    equals the required backend; otherwise an explanatory message is printed.
    """
    tokens = line.split()
    if not tokens:
        # A bare `%%backend` used to fail with an IndexError on `[-1]`.
        print("Usage: %%backend <name>, for example: %%backend jax")
        return
    current, required = os.environ.get("KERAS_BACKEND", ""), tokens[-1]
    if current == required:
        get_ipython().run_cell(cell)
    else:
        print(
            f"This cell requires the {required} backend. To run it, change KERAS_BACKEND to "
            f"\"{required}\" at the top of the notebook, restart the runtime, and rerun the notebook."
        )

## A deep dive on Keras

### A spectrum of workflows

### Different ways to build Keras models

#### The Sequential model

In [None]:
import keras
from keras import layers

# Define a Sequential model by passing the full list of layers to the constructor.
model = keras.Sequential(
    [
        layers.Dense(64, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ]
)

In [None]:
# Same two-layer stack, built incrementally with `add()` instead of a layer list.
model = keras.Sequential()
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))

In [None]:
# Inspect the weights before `build()` has been called or any input shape given.
model.weights

In [None]:
# Build the model for inputs whose last axis has size 3 (first axis left as
# None), then inspect the weights again.
model.build(input_shape=(None, 3))
model.weights

In [None]:
# Print the model summary with a line length of 80.
model.summary(line_length=80)

In [None]:
# Give the model and each layer an explicit name, then build and summarize it.
model = keras.Sequential(name="my_example_model")
model.add(layers.Dense(64, activation="relu", name="my_first_layer"))
model.add(layers.Dense(10, activation="softmax", name="my_last_layer"))
model.build((None, 3))
model.summary(line_length=80)

In [None]:
# Declare the input shape up front by adding a `keras.Input` before the first layer.
model = keras.Sequential()
model.add(keras.Input(shape=(3,)))
model.add(layers.Dense(64, activation="relu"))

In [None]:
# Summarize the model as it stands with only the first Dense layer added.
model.summary(line_length=80)

In [None]:
# Append the 10-unit softmax layer and summarize again.
model.add(layers.Dense(10, activation="softmax"))
model.summary(line_length=80)

#### The Functional API

##### A simple example

In [None]:
# Functional API: call layers on an input object, then wrap the resulting
# input/output pair in a named Model.
inputs = keras.Input(shape=(3,), name="my_input")
features = layers.Dense(64, activation="relu")(inputs)
outputs = layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs=inputs, outputs=outputs, name="my_functional_model")

In [None]:
# Step by step: start from a named input of shape (3,).
inputs = keras.Input(shape=(3,), name="my_input")

In [None]:
# Shape of the input object.
inputs.shape

In [None]:
# Data type of the input object.
inputs.dtype

In [None]:
# Calling a layer on the input object produces the layer's output object.
features = layers.Dense(64, activation="relu")(inputs)

In [None]:
# Shape of the 64-unit Dense layer's output.
features.shape

In [None]:
# Add the output layer and instantiate the Model from its inputs and outputs.
outputs = layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs=inputs, outputs=outputs, name="my_functional_model")

In [None]:
# Summarize the functional model.
model.summary(line_length=80)

##### Multi-input, multi-output models

In [None]:
# Sizes used by the inputs and the department head below.
vocabulary_size = 10000
num_tags = 100
num_departments = 4

# Three named inputs: two of width `vocabulary_size`, one of width `num_tags`.
title = keras.Input(shape=(vocabulary_size,), name="title")
text_body = keras.Input(shape=(vocabulary_size,), name="text_body")
tags = keras.Input(shape=(num_tags,), name="tags")

# Concatenate the inputs and pass them through a shared 64-unit hidden layer.
features = layers.Concatenate()([title, text_body, tags])
features = layers.Dense(64, activation="relu", name="dense_features")(features)

# Two output heads on the shared features: a single sigmoid unit and a
# softmax over `num_departments` classes.
priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(
    num_departments, activation="softmax", name="department"
)(features)

# Model with three inputs and two outputs.
model = keras.Model(
    inputs=[title, text_body, tags],
    outputs=[priority, department],
)

##### Training a multi-input, multi-output model

In [None]:
import numpy as np

num_samples = 1280

# Random dummy inputs: 0/1 integer matrices matching each input's width.
title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))

# Random dummy targets: floats in [0, 1) for priority, integer class ids in
# [0, num_departments) for department.
priority_data = np.random.random(size=(num_samples, 1))
department_data = np.random.randint(0, num_departments, size=(num_samples, 1))

# Losses and metrics are given as lists, in the same order as the model's
# outputs ([priority, department]).
model.compile(
    optimizer="adam",
    loss=["mean_squared_error", "sparse_categorical_crossentropy"],
    metrics=[["mean_absolute_error"], ["accuracy"]],
)
# Inputs and targets are likewise passed as lists, in input/output order.
model.fit(
    [title_data, text_body_data, tags_data],
    [priority_data, department_data],
    epochs=1,
)
model.evaluate(
    [title_data, text_body_data, tags_data], [priority_data, department_data]
)
priority_preds, department_preds = model.predict(
    [title_data, text_body_data, tags_data]
)

In [None]:
# Same workflow as the list-based version, but everything is keyed by name:
# the dict keys are the names given to the Input objects ("title",
# "text_body", "tags") and to the output layers ("priority", "department").
model.compile(
    optimizer="adam",
    loss={
        "priority": "mean_squared_error",
        "department": "sparse_categorical_crossentropy",
    },
    metrics={
        "priority": ["mean_absolute_error"],
        "department": ["accuracy"],
    },
)
model.fit(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data},
    {"priority": priority_data, "department": department_data},
    epochs=1,
)
model.evaluate(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data},
    {"priority": priority_data, "department": department_data},
)
priority_preds, department_preds = model.predict(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data}
)

##### The power of the Functional API: Access to layer connectivity

###### Plotting layer connectivity

In [None]:
# Draw the model's layer graph into ticket_classifier.png.
keras.utils.plot_model(model, "ticket_classifier.png")

In [None]:
# Same plot, with shape information and layer names turned on.
keras.utils.plot_model(
    model,
    "ticket_classifier_with_shape_info.png",
    show_shapes=True,
    show_layer_names=True,
)

###### Feature extraction with a Functional model

In [None]:
# List the layers that make up the model.
model.layers

In [None]:
# Input of the layer at index 3.
# NOTE(review): presumably the Concatenate layer — confirm against `model.layers`.
model.layers[3].input

In [None]:
# Output of the layer at index 3.
model.layers[3].output

In [None]:
# Reuse the output of the layer at index 4 as shared features and attach a
# new 3-class softmax head named "difficulty" to it.
# NOTE(review): index 4 is presumably the "dense_features" layer — confirm.
features = model.layers[4].output
difficulty = layers.Dense(3, activation="softmax", name="difficulty")(features)

# New model over the same inputs that returns the two original outputs plus
# the new one.
new_model = keras.Model(
    inputs=[title, text_body, tags], outputs=[priority, department, difficulty]
)

In [None]:
# Plot the extended model, with shapes and layer names, to a separate file.
keras.utils.plot_model(
    new_model,
    "updated_ticket_classifier.png",
    show_shapes=True,
    show_layer_names=True,
)

#### Subclassing the Model class

##### Rewriting our previous example as a subclassed model

In [None]:
class CustomerTicketModel(keras.Model):
    """Subclassed ticket model with a priority head and a department head.

    Args:
        num_departments: Number of units in the softmax department head.
    """

    def __init__(self, num_departments):
        super().__init__()
        # Layers are created here and wired together in `call()`.
        self.concat_layer = layers.Concatenate()
        self.mixing_layer = layers.Dense(64, activation="relu")
        self.priority_scorer = layers.Dense(1, activation="sigmoid")
        self.department_classifier = layers.Dense(
            num_departments, activation="softmax"
        )

    def call(self, inputs):
        """Forward pass.

        Args:
            inputs: Mapping with the keys "title", "text_body" and "tags".

        Returns:
            A `(priority, department)` tuple of head outputs.
        """
        parts = [inputs[key] for key in ("title", "text_body", "tags")]
        mixed = self.mixing_layer(self.concat_layer(parts))
        return self.priority_scorer(mixed), self.department_classifier(mixed)

In [None]:
# Instantiate the subclassed model and call it directly on a dict of the
# NumPy arrays created earlier.
model = CustomerTicketModel(num_departments=4)

priority, department = model(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data}
)

In [None]:
# Compile and train the subclassed model. Inputs are passed as a dict (the
# structure `call()` indexes into); losses, metrics and targets are lists in
# the order `call()` returns its outputs (priority, department).
model.compile(
    optimizer="adam",
    loss=["mean_squared_error", "sparse_categorical_crossentropy"],
    metrics=[["mean_absolute_error"], ["accuracy"]],
)
model.fit(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data},
    [priority_data, department_data],
    epochs=1,
)
model.evaluate(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data},
    [priority_data, department_data],
)
priority_preds, department_preds = model.predict(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data}
)

##### Beware: What subclassed models don't support

#### Mixing and matching different components

In [None]:
class Classifier(keras.Model):
    """Single-Dense-layer classification head.

    Args:
        num_classes: With exactly 2 classes the head is one sigmoid unit;
            otherwise it is a softmax over `num_classes` units.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        is_binary = num_classes == 2
        head_config = (1, "sigmoid") if is_binary else (num_classes, "softmax")
        num_units, activation = head_config
        self.dense = layers.Dense(num_units, activation=activation)

    def call(self, inputs):
        """Apply the Dense head to `inputs`."""
        return self.dense(inputs)

# Use the subclassed `Classifier` as a layer inside a Functional model.
inputs = keras.Input(shape=(3,))
features = layers.Dense(64, activation="relu")(inputs)
outputs = Classifier(num_classes=10)(features)
model = keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# A small Functional model: 64 input features -> one sigmoid unit.
inputs = keras.Input(shape=(64,))
outputs = layers.Dense(1, activation="sigmoid")(inputs)
binary_classifier = keras.Model(inputs=inputs, outputs=outputs)

class MyModel(keras.Model):
    """Subclassed model that reuses the Functional `binary_classifier` as its head.

    Args:
        num_classes: Accepted but not used by this class.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # 64-unit hidden layer feeding the module-level Functional model.
        self.dense = layers.Dense(64, activation="relu")
        self.classifier = binary_classifier

    def call(self, inputs):
        """Run the hidden layer, then the Functional classifier, on `inputs`."""
        return self.classifier(self.dense(inputs))

model = MyModel()

#### Remember: Use the right tool for the job

### Using built-in training and evaluation loops

In [None]:
from keras.datasets import mnist

def get_mnist_model():
    """Return a new, uncompiled Functional model for flattened 28x28 inputs.

    Architecture: Dense(512, relu) -> Dropout(0.5) -> Dense(10, softmax).
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probabilities = layers.Dense(10, activation="softmax")(hidden)
    return keras.Model(pixels, class_probabilities)

# Load MNIST, flatten each image to a 784-vector and scale pixels by 1/255.
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255
# Hold out the first 10,000 training samples for validation.
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]

# Standard workflow: compile, fit with validation data, evaluate, predict.
model = get_mnist_model()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    train_images,
    train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
)
test_metrics = model.evaluate(test_images, test_labels)
predictions = model.predict(test_images)

#### Writing your own metrics

In [None]:
from keras import ops

class RootMeanSquaredError(keras.metrics.Metric):
    """Streaming RMSE between one-hot encoded integer labels and predictions.

    State lives in two scalar weights: a running sum of squared errors and a
    running count of samples.
    """

    def __init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        # Accumulators, both initialised to zero.
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Fold one batch into the accumulators (`sample_weight` is not used)."""
        pred_shape = ops.shape(y_pred)
        # One-hot encode the labels to the width of the predictions.
        one_hot_targets = ops.one_hot(y_true, num_classes=pred_shape[1])
        squared_errors = ops.square(one_hot_targets - y_pred)
        self.mse_sum.assign_add(ops.sum(squared_errors))
        self.total_samples.assign_add(pred_shape[0])

    def result(self):
        """Return sqrt(accumulated squared error / accumulated sample count)."""
        mean_squared_error = self.mse_sum / self.total_samples
        return ops.sqrt(mean_squared_error)

    def reset_state(self):
        """Set both accumulators back to zero."""
        for accumulator in (self.mse_sum, self.total_samples):
            accumulator.assign(0.)

In [None]:
# Train and evaluate with the custom metric passed as an instance alongside
# the built-in "accuracy" metric.
model = get_mnist_model()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", RootMeanSquaredError()],
)
model.fit(
    train_images,
    train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
)
test_metrics = model.evaluate(test_images, test_labels)

#### 代码解释

这段代码定义了一个自定义的 **Keras 指标（Metric）类** —— `RootMeanSquaredError`（RMSE，均方根误差），它继承自 `keras.metrics.Metric`。
我们逐行来解释：

---

##### 🌱 1️⃣ 导入与类定义

```python
from keras import ops

class RootMeanSquaredError(keras.metrics.Metric):
```

* `keras.metrics.Metric` 是所有自定义指标的基类。
* `ops` 是 Keras 的后端操作模块（兼容 TensorFlow、JAX、Torch），提供了如 `ops.sum()`、`ops.sqrt()` 等张量运算。

---

##### 🌿 2️⃣ 初始化方法 `__init__`

```python
def __init__(self, name="rmse", **kwargs):
    super().__init__(name=name, **kwargs)
    self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
    self.total_samples = self.add_weight(name="total_samples", initializer="zeros")
```

* `name="rmse"` 是该指标的默认名称，训练日志中会以 `rmse`（验证集为 `val_rmse`）显示。注意：这是自定义指标，不能用字符串 `metrics=["rmse"]` 引用，必须传入实例，例如 `metrics=[RootMeanSquaredError()]`。
* `self.add_weight()` 定义两个**状态变量（state variables）**，在每个 batch 更新：

  * `mse_sum`：累计所有样本的平方误差和
  * `total_samples`：累计样本数
* 初始化为 0。

---

##### 🍃 3️⃣ 更新状态 `update_state`

```python
def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = ops.one_hot(y_true, num_classes=ops.shape(y_pred)[1])
    mse = ops.sum(ops.square(y_true - y_pred))
    self.mse_sum.assign_add(mse)
    num_samples = ops.shape(y_pred)[0]
    self.total_samples.assign_add(num_samples)
```

每处理一个 batch，Keras 会调用一次 `update_state()` 来更新指标值。

* `y_true = ops.one_hot(...)`
  将标签转为 one-hot 向量（因为预测输出是 softmax）。
  比如 y_true = `[1, 2]` → one-hot → `[[0,1,0], [0,0,1]]`

* `ops.square(y_true - y_pred)`：计算逐元素平方误差。

* `ops.sum(...)`：对整个 batch 求和，得到该 batch 的总平方误差。

* `assign_add()`：累加到全局的 `mse_sum` 和 `total_samples`。

---

##### 🌸 4️⃣ 计算结果 `result`

```python
def result(self):
    return ops.sqrt(self.mse_sum / self.total_samples)
```

* 计算整体 RMSE = √(MSE)，即均方误差的平方根。

---

##### 🍂 5️⃣ 重置状态 `reset_state`

```python
def reset_state(self):
    self.mse_sum.assign(0.)
    self.total_samples.assign(0.)
```

* 在每个 epoch 开始时（以及每次评估开始时），Keras 会调用此方法清零状态，避免跨 epoch 的统计互相混淆。

---

##### ✅ 6️⃣ 总结流程图

```
初始化：
  mse_sum = 0, total_samples = 0

每批次：
  计算当前 batch 的平方误差 → 累加到 mse_sum
  计算样本数 → 累加到 total_samples

最终结果：
  rmse = sqrt(mse_sum / total_samples)

epoch 结束：
  reset_state()
```

---

##### 📘 使用示例

```python
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=[RootMeanSquaredError()]
)
```

训练时，Keras 会自动：

* 在每个 batch 调用 `update_state`
* 在每个 epoch 结束时计算 `result`
* 在新 epoch 开始时调用 `reset_state`

---

#### Using callbacks

##### The EarlyStopping and ModelCheckpoint callbacks

In [None]:
# Two callbacks passed to `fit()`:
# - EarlyStopping watches the "accuracy" metric with patience=1.
# - ModelCheckpoint writes to checkpoint_path.keras, monitoring "val_loss"
#   with save_best_only=True.
# NOTE(review): early stopping monitors the training metric while the
# checkpoint monitors a validation metric — confirm this is intended.
callbacks_list = [
    keras.callbacks.EarlyStopping(
        monitor="accuracy",
        patience=1,
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path.keras",
        monitor="val_loss",
        save_best_only=True,
    ),
]
model = get_mnist_model()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    train_images,
    train_labels,
    epochs=10,
    callbacks=callbacks_list,
    validation_data=(val_images, val_labels),
)

In [None]:
# Reload the model from the file the ModelCheckpoint callback writes to.
model = keras.models.load_model("checkpoint_path.keras")

#### Writing your own callbacks

In [None]:
from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback):
    """Record the training loss of every batch and save one plot per epoch.

    The hooks accept `logs=None`, matching the base `Callback` signatures,
    and a missing `logs` dict is treated as empty.
    """

    def on_train_begin(self, logs=None):
        """Start with an empty list of per-batch losses."""
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        """Append the "loss" entry of `logs` (None if absent) for this batch."""
        self.per_batch_losses.append((logs or {}).get("loss"))

    def on_epoch_end(self, epoch, logs=None):
        """Plot this epoch's per-batch losses, save the figure, reset the list."""
        # Clear the current pyplot figure so each epoch gets a fresh plot.
        plt.clf()
        plt.plot(
            range(len(self.per_batch_losses)),
            self.per_batch_losses,
            label="Training loss for each batch",
        )
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}", dpi=300)
        # Start collecting from scratch for the next epoch.
        self.per_batch_losses = []

In [None]:
# Train a fresh model with the custom LossHistory callback attached.
model = get_mnist_model()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    train_images,
    train_labels,
    epochs=10,
    callbacks=[LossHistory()],
    validation_data=(val_images, val_labels),
)

#### Monitoring and visualization with TensorBoard

In [None]:
# Train a fresh model while logging to TensorBoard.
model = get_mnist_model()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# NOTE: "/full_path_to_your_log_dir" is a placeholder; replace it with a real
# directory before running.
tensorboard = keras.callbacks.TensorBoard(
    log_dir="/full_path_to_your_log_dir",
)
model.fit(
    train_images,
    train_labels,
    epochs=10,
    validation_data=(val_images, val_labels),
    callbacks=[tensorboard],
)

In [None]:
# Load the TensorBoard notebook extension and point it at the same
# (placeholder) log directory given to the TensorBoard callback.
%load_ext tensorboard
%tensorboard --logdir /full_path_to_your_log_dir

### Writing your own training and evaluation loops

#### Training vs. inference

#### Writing custom training step functions

##### A TensorFlow training step function

In [None]:
%%backend tensorflow
import tensorflow as tf

# Objects the training step below reads from the enclosing scope.
model = get_mnist_model()
loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam()

def train_step(inputs, targets):
    """Run one TensorFlow training step on a batch and return its loss.

    Uses the enclosing-scope `model`, `loss_fn` and `optimizer`.
    """
    # Forward pass and loss are computed under the tape.
    with tf.GradientTape() as tape:
        batch_loss = loss_fn(targets, model(inputs, training=True))
    weights = model.trainable_weights
    optimizer.apply(tape.gradient(batch_loss, weights), weights)
    return batch_loss

In [None]:
%%backend tensorflow
# Run a single training step on the first 32 training samples.
batch_size = 32
inputs = train_images[:batch_size]
targets = train_labels[:batch_size]
loss = train_step(inputs, targets)

###### 代码解释

这段代码展示了 **Keras + TensorFlow 自定义训练循环（custom training loop）** 的基本原理。
我们逐行来看它做了什么：

---

####### 🌱 1️⃣ 环境与准备

```python
%%backend tensorflow
import tensorflow as tf
```

* `%%backend tensorflow` 是本笔记本开头自定义的单元格魔法命令：只有当环境变量 `KERAS_BACKEND` 等于 `tensorflow` 时才执行该单元格，否则只打印提示信息。它本身并不会切换后端。
* `import tensorflow as tf` 导入 TensorFlow 库。

---

####### 🌿 2️⃣ 创建模型、损失函数与优化器

```python
model = get_mnist_model()
loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam()
```

* `get_mnist_model()` 是前面定义的函数，返回一个未编译的 MNIST 分类模型：全连接网络 `Dense(512, relu)` → `Dropout(0.5)` → `Dense(10, softmax)`。
* `loss_fn` 是稀疏分类交叉熵：

  * 适用于整数标签的多分类问题；
  * 输入标签如 `[3, 1, 7, ...]`；
  * 输出预测为 softmax 概率分布。
* `optimizer = Adam()` 用来根据梯度更新参数。

---

####### 🍃 3️⃣ 定义训练步骤函数

```python
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
```

######## ✳️ `tf.GradientTape()` 的作用

* 这是 TensorFlow 的自动微分机制；
* 它**记录**模型前向传播中涉及的所有可训练变量；
* 之后可以调用 `tape.gradient()` 自动计算这些变量的梯度。

流程：

1. **前向传播（forward pass）**
   `predictions = model(inputs, training=True)`
   模型在训练模式下执行前向计算；
2. **计算损失（loss）**
   `loss_fn(targets, predictions)` 计算预测与真实标签之间的误差。

---

####### 🍀 4️⃣ 计算梯度并更新权重

```python
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply(gradients, model.trainable_weights)
```

* `tape.gradient(loss, model.trainable_weights)`
  自动计算损失相对于每个可训练参数的梯度。
* `optimizer.apply(gradients, model.trainable_weights)`
  用上一步算出的梯度更新参数（反向传播已由 `tape.gradient()` 完成，这一步只负责参数更新）。

---

####### 🌸 5️⃣ 返回损失

```python
    return loss
```

每次调用 `train_step()`，模型会完成：

1. 前向传播
2. 计算损失
3. 自动求梯度
4. 执行梯度下降（参数更新）
5. 返回当前 batch 的损失值

---

####### 🔁 6️⃣ 在训练循环中使用

你可以像这样调用它：

```python
for epoch in range(num_epochs):
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        loss_value = train_step(x_batch, y_batch)
    print(f"Epoch {epoch+1}, Loss: {loss_value.numpy():.4f}")
```

这样你就实现了一个完全**自定义的训练循环**。
相比 `model.fit()`，这种方式能：

* 完全控制训练逻辑；
* 在训练过程中加入特殊操作（如梯度裁剪、对抗训练、可视化等）；
* 便于研究实验性网络结构。

---

######## ✅ 总结逻辑流程图：

```
train_step(inputs, targets):
    1. 开始记录计算图 (tf.GradientTape)
    2. 前向传播: predictions = model(inputs, training=True)
    3. 计算损失: loss = loss_fn(targets, predictions)
    4. 自动微分: gradients = tape.gradient(loss, model.trainable_weights)
    5. 应用梯度: optimizer.apply(gradients, model.trainable_weights)
    6. 返回当前 loss
```

---


##### A PyTorch training step function

In [None]:
%%backend torch
import torch

# Objects the training step below reads from the enclosing scope.
model = get_mnist_model()
loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam()

def train_step(inputs, targets):
    """Run one PyTorch training step on a batch and return its loss.

    Uses the enclosing-scope `model`, `loss_fn` and `optimizer`.
    """
    batch_loss = loss_fn(targets, model(inputs, training=True))
    # Populate `.grad` on the underlying tensors of the weights.
    batch_loss.backward()
    weights = model.trainable_weights
    gradients = [weight.value.grad for weight in weights]
    # The weight update itself runs with autograd disabled.
    with torch.no_grad():
        optimizer.apply(gradients, weights)
    # Clear gradients so the next step starts from zero.
    model.zero_grad()
    return batch_loss

In [None]:
%%backend torch
# Run a single training step on the first 32 training samples.
batch_size = 32
inputs = train_images[:batch_size]
targets = train_labels[:batch_size]
loss = train_step(inputs, targets)

##### A JAX training step function

In [None]:
%%backend jax
# Model and loss read from the enclosing scope by the function below.
model = get_mnist_model()
loss_fn = keras.losses.SparseCategoricalCrossentropy()

def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, inputs, targets
):
    """Stateless forward pass plus loss.

    All model state is passed in explicitly rather than read from `model`.

    Returns:
        `(loss, non_trainable_variables)`, where the second item is the
        value returned by `model.stateless_call`.
    """
    outputs, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, inputs, training=True
    )
    loss = loss_fn(targets, outputs)
    return loss, non_trainable_variables

In [None]:
%%backend jax
import jax

# `has_aux=True`: the wrapped function returns (loss, aux); here the aux
# value is the non-trainable variables returned by compute_loss_and_updates.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

In [None]:
%%backend jax
# Build the optimizer against the model's trainable variables up front.
optimizer = keras.optimizers.Adam()
optimizer.build(model.trainable_variables)

def train_step(state, inputs, targets):
    """Run one stateless JAX training step.

    Args:
        state: Tuple `(trainable_variables, non_trainable_variables,
            optimizer_variables)`.
        inputs: Batch of inputs.
        targets: Batch of targets.

    Returns:
        `(loss, new_state)`, with `new_state` a tuple in the same order as
        `state`.
    """
    (trainable_variables, non_trainable_variables, optimizer_variables) = state
    # grad_fn returns ((loss, aux), grads) because it was built with has_aux=True.
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, inputs, targets
    )
    # The optimizer update returns new values instead of mutating in place.
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

In [None]:
%%backend jax
# First 32 training samples as the batch.
batch_size = 32
inputs = train_images[:batch_size]
targets = train_labels[:batch_size]

# Extract the current values of the model and optimizer variables.
trainable_variables = [v.value for v in model.trainable_variables]
non_trainable_variables = [v.value for v in model.non_trainable_variables]
optimizer_variables = [v.value for v in optimizer.variables]

# Pack them into the state tuple and run one step, which returns new state.
state = (trainable_variables, non_trainable_variables, optimizer_variables)
loss, state = train_step(state, inputs, targets)

#### Low-level usage of metrics

In [None]:
from keras import ops

# Use a metric object directly: update it with targets and predictions, then
# read and print the current result.
metric = keras.metrics.SparseCategoricalAccuracy()
targets = ops.array([0, 1, 2])
predictions = ops.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
metric.update_state(targets, predictions)
current_result = metric.result()
print(f"result: {current_result:.2f}")

In [None]:
# Feed values one at a time into a `Mean` metric and print its result.
values = ops.array([0, 1, 2, 3, 4])
mean_tracker = keras.metrics.Mean()
for value in values:
    mean_tracker.update_state(value)
print(f"Mean of values: {mean_tracker.result():.2f}")

In [None]:
# Stateless variant of the metric API: the metric's variables are passed in
# explicitly and the updated values are returned to the caller.
metric = keras.metrics.SparseCategoricalAccuracy()
targets = ops.array([0, 1, 2])
predictions = ops.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

metric_variables = metric.variables
metric_variables = metric.stateless_update_state(
    metric_variables, targets, predictions
)
current_result = metric.stateless_result(metric_variables)
print(f"result: {current_result:.2f}")

# Obtain variable values for a reset metric.
metric_variables = metric.stateless_reset_state()

##### 代码解释

这段代码演示了 **Keras 指标（Metric）对象的工作原理** ——
以 `SparseCategoricalAccuracy`（稀疏分类准确率）为例，展示了如何手动计算模型预测的准确率。
我们一步步来看 👇

---

###### 🧩 1️⃣ 导入 ops 并创建指标对象

```python
from keras import ops
metric = keras.metrics.SparseCategoricalAccuracy()
```

* `ops` 是 Keras 后端的通用操作接口（在不同后端如 TensorFlow、JAX、PyTorch 中保持一致的 API）。
* `keras.metrics.SparseCategoricalAccuracy()` 表示稀疏分类准确率：

  * 用于多分类任务；
  * 真实标签 (`y_true`) 是整数索引，如 `[0, 1, 2]`；
  * 预测结果 (`y_pred`) 是 softmax 概率向量，如 `[[0.8, 0.1, 0.1], ...]`。

---

###### 📊 2️⃣ 构造测试数据

```python
targets = ops.array([0, 1, 2])
predictions = ops.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
```

* `targets`：真实类别为

  ```
  样本1 → 类别0  
  样本2 → 类别1  
  样本3 → 类别2
  ```
* `predictions`：每个样本的预测概率向量

  ```
  样本1预测为类0（概率最高）  
  样本2预测为类1  
  样本3预测为类2
  ```

  因此，预测是完全正确的。

---

###### ⚙️ 3️⃣ 更新指标状态

```python
metric.update_state(targets, predictions)
```

####### 这一步的作用：

* `update_state()` 会比较真实标签与预测结果；
* 内部计算：

  1. `pred_class = argmax(predictions, axis=-1)` → `[0, 1, 2]`
  2. 与 `targets` `[0, 1, 2]` 比较；
  3. 匹配成功的样本数 = 3；
  4. 总样本数 = 3；
  5. 暂存到内部状态（`metric.total`, `metric.count`）。

---

###### 📈 4️⃣ 获取计算结果

```python
current_result = metric.result()
```

* `result()` 根据内部状态计算准确率：

  $$
  \text{accuracy} = \frac{\text{正确样本数}}{\text{总样本数}} = \frac{3}{3} = 1.0
  $$

---

###### 🖨️ 5️⃣ 打印结果

```python
print(f"result: {current_result:.2f}")
```

输出：

```
result: 1.00
```

---

###### ✅ 总结流程

| 步骤 | 函数                                                   | 含义                   |
| -- | ---------------------------------------------------- | -------------------- |
| 1  | `metric = keras.metrics.SparseCategoricalAccuracy()` | 创建指标对象               |
| 2  | `metric.update_state(y_true, y_pred)`                | 更新内部状态（统计正确预测数与总样本数） |
| 3  | `metric.result()`                                    | 计算当前准确率              |
| 4  | `metric.reset_state()`（未展示）                          | 清空累计统计，准备下一轮计算       |

---

####### 💡拓展：如果预测不完全正确

假如：

```python
predictions = ops.array([[1, 0, 0], [0, 0, 1], [0, 1, 0]])
```

即只有第一个样本预测正确，
则：
$$
\text{accuracy} = \frac{1}{3} \approx 0.33
$$
输出会变为：

```
result: 0.33
```

---




#### Using fit() with a custom training loop
将自定义训练步骤与 Keras 内置 `fit()` 训练循环的功能结合起来。

##### Customizing fit() with TensorFlow

In [None]:
%%backend tensorflow
import keras
from keras import layers

# Module-level loss function and running-mean tracker used by CustomModel below.
loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")

class CustomModel(keras.Model):
    """Model with a hand-written TensorFlow `train_step()`."""

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: An `(inputs, targets)` pair.

        Returns:
            Dict with the running mean loss under the key "loss".
        """
        batch_inputs, batch_targets = data
        with tf.GradientTape() as tape:
            batch_loss = loss_fn(batch_targets, self(batch_inputs, training=True))
        weights = self.trainable_weights
        self.optimizer.apply(tape.gradient(batch_loss, weights), weights)

        loss_tracker.update_state(batch_loss)
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        """Expose the module-level loss tracker as this model's only metric."""
        return [loss_tracker]

In [None]:
%%backend tensorflow
def get_custom_model():
    """Build a `CustomModel` for flattened 28x28 inputs and compile it with Adam.

    Only an optimizer is given to `compile()`.
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probabilities = layers.Dense(10, activation="softmax")(hidden)
    custom_model = CustomModel(pixels, class_probabilities)
    custom_model.compile(optimizer=keras.optimizers.Adam())
    return custom_model

In [None]:
%%backend tensorflow
# `fit()` drives the overridden `train_step()`.
model = get_custom_model()
model.fit(train_images, train_labels, epochs=3)

##### Customizing fit() with PyTorch

In [None]:
%%backend torch
import keras
from keras import layers

# Module-level loss function and running-mean tracker used by CustomModel below.
loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")

class CustomModel(keras.Model):
    """Model with a hand-written PyTorch `train_step()`."""

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: An `(inputs, targets)` pair.

        Returns:
            Dict with the running mean loss under the key "loss".
        """
        inputs, targets = data
        # Clear gradients left over from the previous step; without this,
        # `backward()` adds each batch's gradients on top of the old ones.
        self.zero_grad()
        predictions = self(inputs, training=True)
        loss = loss_fn(targets, predictions)

        # Populate `.grad` on the underlying tensors of the weights.
        loss.backward()
        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # The weight update itself runs with autograd disabled.
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        loss_tracker.update_state(loss)
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        """Expose the module-level loss tracker as this model's only metric."""
        return [loss_tracker]

In [None]:
%%backend torch
def get_custom_model():
    """Build a `CustomModel` for flattened 28x28 inputs and compile it with Adam.

    Only an optimizer is given to `compile()`.
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probabilities = layers.Dense(10, activation="softmax")(hidden)
    custom_model = CustomModel(pixels, class_probabilities)
    custom_model.compile(optimizer=keras.optimizers.Adam())
    return custom_model

In [None]:
%%backend torch
# `fit()` drives the overridden `train_step()`.
model = get_custom_model()
model.fit(train_images, train_labels, epochs=3)

##### Customizing fit() with JAX

In [None]:
%%backend jax
import keras
from keras import layers

# Module-level loss function used by CustomModel below.
loss_fn = keras.losses.SparseCategoricalCrossentropy()

class CustomModel(keras.Model):
    """Model with a hand-written, stateless JAX `train_step()`."""

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        inputs,
        targets,
        training=False,
    ):
        """Stateless forward pass plus loss.

        Returns:
            `(loss, non_trainable_variables)`, where the second item is the
            value returned by `stateless_call`.
        """
        predictions, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            inputs,
            training=training,
        )
        loss = loss_fn(targets, predictions)
        return loss, non_trainable_variables

    def train_step(self, state, data):
        """Run one training step.

        Args:
            state: Tuple `(trainable_variables, non_trainable_variables,
                optimizer_variables, metrics_variables)`.
            data: An `(inputs, targets)` pair.

        Returns:
            `(logs, state)`: a dict holding this batch's loss, and a state
            tuple in the same order as the input `state`.
        """
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        inputs, targets = data

        # has_aux=True: compute_loss_and_updates returns (loss, aux).
        grad_fn = jax.value_and_grad(
            self.compute_loss_and_updates, has_aux=True
        )

        (loss, non_trainable_variables), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            inputs,
            targets,
            training=True,
        )

        # The optimizer update returns new values instead of mutating in place.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # The logged loss is this batch's value; metrics_variables is passed
        # through unchanged.
        logs = {"loss": loss}
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        )
        return logs, state

In [None]:
%%backend jax
def get_custom_model():
    """Build a `CustomModel` for flattened 28x28 inputs and compile it with Adam.

    Only an optimizer is given to `compile()`.
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probabilities = layers.Dense(10, activation="softmax")(hidden)
    custom_model = CustomModel(pixels, class_probabilities)
    custom_model.compile(optimizer=keras.optimizers.Adam())
    return custom_model

In [None]:
%%backend jax
# `fit()` drives the overridden `train_step()`.
model = get_custom_model()
model.fit(train_images, train_labels, epochs=3)

#### Handling metrics in a custom train_step()

##### train_step() metrics handling with TensorFlow

In [None]:
%%backend tensorflow
import keras
from keras import layers

class CustomModel(keras.Model):
    """Model whose TensorFlow `train_step()` reuses the compiled loss and metrics."""

    def train_step(self, data):
        """Run one optimization step; invoked once per batch.

        Args:
            data: An `(inputs, targets)` pair.

        Returns:
            Dict mapping each metric's name to its current result.
        """
        batch_inputs, batch_targets = data
        with tf.GradientTape() as tape:
            predictions = self(batch_inputs, training=True)
            batch_loss = self.compute_loss(y=batch_targets, y_pred=predictions)

        # Backward pass: the weights are updated on every batch.
        weights = self.trainable_weights
        self.optimizer.apply(tape.gradient(batch_loss, weights), weights)

        # The metric named "loss" is fed the loss value; every other metric
        # is fed targets and predictions.
        for tracker in self.metrics:
            if tracker.name == "loss":
                update_args = (batch_loss,)
            else:
                update_args = (batch_targets, predictions)
            tracker.update_state(*update_args)

        results = {}
        for tracker in self.metrics:
            results[tracker.name] = tracker.result()
        return results

In [None]:
%%backend tensorflow
def get_custom_model():
    """Build a `CustomModel` for flattened 28x28 inputs and compile it.

    Compiled with Adam, sparse categorical crossentropy and sparse
    categorical accuracy.
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probabilities = layers.Dense(10, activation="softmax")(hidden)
    custom_model = CustomModel(pixels, class_probabilities)
    custom_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return custom_model

# `fit()` drives the overridden `train_step()`, which reports the compiled metrics.
model = get_custom_model()
model.fit(train_images, train_labels, epochs=3)

###### 代码说明

这段代码定义了一个自定义的 Keras 模型，重写了训练步骤以实现更灵活的训练过程。让我逐部分详细解释：

####### 代码结构分析

######## 1. 基础设置和导入
```python
%%backend tensorflow
import keras
from keras import layers
```
- `%%backend tensorflow`: 本笔记本开头自定义的单元格魔法命令，仅当 `KERAS_BACKEND` 为 `tensorflow` 时才执行该单元格（它不会切换后端）
- 导入 Keras 框架和层模块

######## 2. 自定义模型类定义
```python
class CustomModel(keras.Model):
```
创建一个继承自 `keras.Model` 的自定义模型类

######## 3. 重写训练步骤方法
```python
def train_step(self, data):
```
重写 `train_step` 方法，这是训练过程中的核心方法，每个批次(batch)都会调用

######## 4. 数据解包和前向传播
```python
inputs, targets = data
with tf.GradientTape() as tape:
    predictions = self(inputs, training=True)
    loss = self.compute_loss(y=targets, y_pred=predictions)
```
- **数据解包**: `inputs, targets = data` - 将输入数据分为输入和标签
- **梯度记录**: `tf.GradientTape()` - 创建梯度记录上下文，跟踪所有可训练变量的操作
- **前向传播**: `self(inputs, training=True)` - 调用模型进行前向传播
- **损失计算**: `self.compute_loss()` - 使用模型编译时定义的损失函数计算损失

######## 5. 反向传播和权重更新
```python
gradients = tape.gradient(loss, self.trainable_weights)
self.optimizer.apply(gradients, self.trainable_weights)
```
- **梯度计算**: `tape.gradient()` - 计算损失相对于所有可训练权重的梯度
- **权重更新**: `optimizer.apply()` - 使用优化器应用梯度更新权重

######## 6. 指标更新
```python
for metric in self.metrics:
    if metric.name == "loss":
        metric.update_state(loss)
    else:
        metric.update_state(targets, predictions)
```
- 遍历所有在模型编译时定义的指标
- 如果是损失指标，直接使用计算出的损失值
- 对于其他指标（如准确率），使用真实标签和预测值更新

######## 7. 返回训练指标
```python
return {m.name: m.result() for m in self.metrics}
```
返回一个字典，包含所有指标的当前值

####### 与标准训练的区别

**标准 Keras 训练**：
- 自动处理梯度计算和权重更新
- 内置损失计算和指标更新

**自定义训练的优势**：
1. **灵活性**: 可以自定义训练逻辑
2. **复杂场景**: 支持多任务学习、对抗训练等复杂场景
3. **调试**: 更容易调试和监控训练过程
4. **研究**: 适合研究和实验新的训练方法

####### 使用示例
```python
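# 创建自定义模型实例（这里的 CustomModel 没有定义 call()，需按函数式 API 传入 inputs 和 outputs）
inputs = keras.Input(shape=(28 * 28,))
outputs = layers.Dense(10, activation="softmax")(inputs)
model = CustomModel(inputs, outputs)

# 编译模型（定义损失函数、优化器和指标；损失会被自动跟踪，不需要也不能写成 'loss' 指标）
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)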

# 正常训练
model.fit(x_train, y_train, epochs=10)
```

这种自定义训练步骤的方式为复杂的训练场景提供了极大的灵活性。

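自定义 `train_step()` 时，需要自己实现指标的计算吗？
不用——指标的计算逻辑由 `compile()` 中传入的指标对象完成；但需要像上面的代码那样，在 `train_step()` 中手动调用各指标的 `update_state()`，并返回它们的 `result()`。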

##### train_step() metrics handling with PyTorch

In [None]:
%%backend torch
import keras
from keras import layers

class CustomModel(keras.Model):
    """Model whose PyTorch `train_step()` reuses the compiled loss and metrics."""

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: An `(inputs, targets)` pair.

        Returns:
            Dict mapping each metric's name to its current result.
        """
        inputs, targets = data
        # Clear gradients left over from the previous step; without this,
        # `backward()` adds each batch's gradients on top of the old ones.
        self.zero_grad()
        predictions = self(inputs, training=True)
        loss = self.compute_loss(y=targets, y_pred=predictions)

        # Populate `.grad` on the underlying tensors of the weights.
        loss.backward()
        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # The weight update itself runs with autograd disabled.
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # The metric named "loss" is fed the loss value; every other metric
        # is fed targets and predictions.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

In [None]:
%%backend torch
def get_custom_model():
    """Build a `CustomModel` for flattened 28x28 inputs and compile it.

    Compiled with Adam, sparse categorical crossentropy and sparse
    categorical accuracy.
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probabilities = layers.Dense(10, activation="softmax")(hidden)
    custom_model = CustomModel(pixels, class_probabilities)
    custom_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return custom_model

# `fit()` drives the overridden `train_step()`, which reports the compiled metrics.
model = get_custom_model()
model.fit(train_images, train_labels, epochs=3)

##### train_step() metrics handling with JAX

In [None]:
%%backend jax
import keras
from keras import layers

class CustomModel(keras.Model):
    """Model whose stateless JAX `train_step()` also updates the model's metrics."""

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        inputs,
        targets,
        training=False,
    ):
        """Stateless forward pass plus loss.

        Returns:
            `(loss, (predictions, non_trainable_variables))`. The predictions
            are returned in the aux tuple so `train_step()` can feed them to
            the metrics.
        """
        predictions, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            inputs,
            training=training,
        )
        loss = self.compute_loss(y=targets, y_pred=predictions)
        return loss, (predictions, non_trainable_variables)

    def train_step(self, state, data):
        """Run one training step and update every metric statelessly.

        Args:
            state: Tuple `(trainable_variables, non_trainable_variables,
                optimizer_variables, metrics_variables)`.
            data: An `(inputs, targets)` pair.

        Returns:
            `(logs, state)`: a dict mapping metric names to results, and a
            state tuple in the same order as the input `state`.
        """
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        inputs, targets = data

        # has_aux=True: compute_loss_and_updates returns (loss, aux).
        grad_fn = jax.value_and_grad(
            self.compute_loss_and_updates, has_aux=True
        )

        (loss, (predictions, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            inputs,
            targets,
            training=True,
        )
        # The optimizer update returns new values instead of mutating in place.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # `metrics_variables` is treated as one flat list in which each
        # metric's variables occupy a contiguous slice, in `self.metrics` order.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            # Slice out this metric's variables by running offset.
            num_prev = len(new_metrics_vars)
            num_current = len(metric.variables)
            current_vars = metrics_variables[num_prev : num_prev + num_current]
            # The metric named "loss" is fed the loss value; every other
            # metric is fed targets and predictions.
            if metric.name == "loss":
                current_vars = metric.stateless_update_state(current_vars, loss)
            else:
                current_vars = metric.stateless_update_state(
                    current_vars, targets, predictions
                )
            logs[metric.name] = metric.stateless_result(current_vars)
            new_metrics_vars += current_vars

        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state