In [2]:
from manim import *

In [3]:
%%manim -qh -v WARNING NeuralNetwork3D

import numpy as np

class NeuralNetwork3D(ThreeDScene):
    """Animate a small fully connected feed-forward network in 3D.

    The scene draws one column of spheres per layer along the X axis and joins
    every neuron of each layer to every neuron of the next with a 3D line. It
    then flashes the layers in order to suggest a forward pass, highlights the
    output layer in red, and finishes with a short ambient camera orbit.

    The layout is driven by the class attributes below, so a subclass can
    change the network shape without touching ``construct``.
    """

    # Number of neurons per layer: input -> hidden -> output.
    LAYER_SIZES = (4, 6, 3)
    # Distance between adjacent layers along the X axis. Layers are centred on
    # x = 0, so three layers land at x = -4, 0, 4.
    LAYER_SPACING = 4.0
    # Distance between adjacent neurons within a layer along the Y axis.
    NEURON_SPACING = 0.8
    # Radius of each neuron sphere.
    NEURON_RADIUS = 0.25

    def construct(self):
        # 1. Title, pinned to the frame so it does not move with the camera.
        title = Text("3D Neural Network Visualization", font_size=40)
        self.add_fixed_in_frame_mobjects(title)
        title.to_edge(UP)
        self.play(Write(title))

        # 2. Camera setup.
        self.set_camera_orientation(
            phi=65 * DEGREES,
            theta=-45 * DEGREES,
            zoom=1.2,
        )
        axes = ThreeDAxes()
        self.play(Create(axes))

        # 3-4. Neurons, faded in one layer at a time.
        layers = self._build_layers()
        for neurons in layers:
            self.play(FadeIn(neurons))

        # 5. Connections (weights) between consecutive layers.
        connections = self._build_connections(layers)
        self.play(Create(connections))

        # 6. Forward pass: flash each layer yellow, then return it to blue.
        for layer in layers:
            self.play(
                *[neuron.animate.set_color(YELLOW) for neuron in layer],
                run_time=0.6,
            )
            self.play(
                *[neuron.animate.set_color(BLUE) for neuron in layer],
                run_time=0.4,
            )

        # 7. Highlight the output layer and label it.
        self.play(*[neuron.animate.set_color(RED) for neuron in layers[-1]])
        output_text = Text("Output Layer", font_size=30)
        self.add_fixed_in_frame_mobjects(output_text)
        output_text.to_corner(DR)
        self.play(Write(output_text))

        # 8. Orbit the camera around the network, then hold the final frame.
        self.begin_ambient_camera_rotation(rate=0.3)
        self.wait(4)
        self.stop_ambient_camera_rotation()
        self.wait(2)

    def _layer_x(self, layer_index):
        """Return the X coordinate of layer *layer_index*, centring all layers on x = 0."""
        count = len(self.LAYER_SIZES)
        return (layer_index - (count - 1) / 2) * self.LAYER_SPACING

    def _neuron_y(self, neuron_index, size):
        """Return the Y coordinate of a neuron so a layer of *size* is symmetric about y = 0."""
        # Offset by (size - 1) / 2, not size / 2: the latter shifts every column
        # down by half a spacing (e.g. 4 neurons spanned -1.6..0.8).
        return (neuron_index - (size - 1) / 2) * self.NEURON_SPACING

    def _build_layers(self):
        """Return one VGroup of blue spheres per entry in LAYER_SIZES (not yet added to the scene)."""
        layers = []
        for layer_index, size in enumerate(self.LAYER_SIZES):
            x = self._layer_x(layer_index)
            neurons = VGroup()
            for i in range(size):
                neuron = Sphere(radius=self.NEURON_RADIUS, resolution=(16, 16))
                neuron.set_color(BLUE)
                neuron.move_to(np.array([x, self._neuron_y(i, size), 0]))
                neurons.add(neuron)
            layers.append(neurons)
        return layers

    @staticmethod
    def _build_connections(layers):
        """Return a VGroup of gray 3D lines joining every neuron of each layer to every neuron of the next."""
        connections = VGroup()
        for left, right in zip(layers[:-1], layers[1:]):
            for n1 in left:
                for n2 in right:
                    # NOTE(review): Line3D is rendered as a cylinder whose radius
                    # comes from its `thickness` argument; `stroke_width` only
                    # affects the outline of its surface faces. Confirm whether
                    # a thinner/thicker line was intended.
                    connections.add(
                        Line3D(
                            n1.get_center(),
                            n2.get_center(),
                            stroke_width=1,
                            color=GRAY,
                        )
                    )
        return connections


                                                                                                                