# SPDX-FileCopyrightText: Copyright (c) 2025 The Newton Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###########################################################################
# Example MPM 2-Way Coupling
#
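# Drops a rigid box onto a bed of sand. The box is simulated with the XPBD
# solver and the sand with the implicit MPM solver; impulses exchanged between
# the two provide two-way coupling. The sand force acting on the box is printed
# to the terminal and plotted against time with matplotlib.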
#
# Command: python -m newton.examples mpm_twoway_coupling
#
###########################################################################

from __future__ import annotations

import numpy as np
import warp as wp
import matplotlib.pyplot as plt

import newton
import newton.examples
from newton.solvers import SolverImplicitMPM


# ========== Sand-rigid two-way coupling: force computation kernel ==========
@wp.kernel
def compute_body_forces(
    dt: float,
    collider_ids: wp.array(dtype=int),
    collider_impulses: wp.array(dtype=wp.vec3),
    collider_impulse_pos: wp.array(dtype=wp.vec3),
    body_ids: wp.array(dtype=int),
    body_q: wp.array(dtype=wp.transform),
    body_com: wp.array(dtype=wp.vec3),
    body_f: wp.array(dtype=wp.spatial_vector),
):
    """Compute the forces the sand exerts on rigid bodies.

    Each thread handles one collected collider impulse sample (the impulses
    applied at the MPM grid nodes). The impulse is divided by ``dt`` to obtain
    a force. That force, together with the torque it produces about the body's
    center of mass, is accumulated into ``body_f``.

    Args:
        dt: Interval over which the impulses were accumulated (the caller
            passes the MPM step length).
        collider_ids: Collider index of each impulse sample. Negative ids and
            ids outside ``body_ids`` are ignored.
        collider_impulses: Impulse of each sample.
        collider_impulse_pos: World-space application point of each sample.
        body_ids: Collider index -> rigid-body index map. Entries equal to -1
            are skipped.
        body_q: Body transforms.
        body_com: Body centers of mass in body-local coordinates.
        body_f: Output per-body spatial force (linear force, torque),
            accumulated with atomic adds.
    """

    i = wp.tid()

    cid = collider_ids[i]
    # Skip unused slots (negative ids) and ids outside the collider->body map
    if cid >= 0 and cid < body_ids.shape[0]:
        body_index = body_ids[cid]
        # -1 presumably marks a collider not attached to a dynamic body
        if body_index == -1:
            return

        # Impulse accumulated over dt -> average force
        f_world = collider_impulses[i] / dt

        X_wb = body_q[body_index]
        X_com = body_com[body_index]
        # Lever arm from the world-space center of mass to the application point
        r = collider_impulse_pos[i] - wp.transform_point(X_wb, X_com)
        # Many samples may touch the same body, hence the atomic accumulation
        wp.atomic_add(body_f, body_index, wp.spatial_vector(f_world, wp.cross(r, f_world)))


# ========== Sand-rigid two-way coupling: force subtraction kernel ==========
@wp.kernel
def subtract_body_force(
    dt: float,
    body_q: wp.array(dtype=wp.transform),
    body_qd: wp.array(dtype=wp.spatial_vector),
    body_f: wp.array(dtype=wp.spatial_vector),
    body_inv_inertia: wp.array(dtype=wp.mat33),
    body_inv_mass: wp.array(dtype=float),
    body_q_res: wp.array(dtype=wp.transform),
    body_qd_res: wp.array(dtype=wp.spatial_vector),
):
    """Update rigid-body velocities to remove the force the sand applied in the previous step.

    This is needed to compute the total impulse required to enforce the
    complementarity-based frictional contact boundary condition.

    One thread per body. ``body_q`` is copied unchanged into ``body_q_res``.
    ``body_qd_res`` receives ``body_qd`` minus the velocity change that
    ``body_f`` applied over ``dt`` would produce.

    Args:
        dt: Interval over which ``body_f`` was applied.
        body_q: Body transforms.
        body_qd: Body spatial velocities (linear, angular).
        body_f: Per-body spatial force (linear force, torque) to remove.
        body_inv_inertia: Inverse inertia tensors, expressed in the body frame.
        body_inv_mass: Inverse masses.
        body_q_res: Output transforms (copy of ``body_q``).
        body_qd_res: Output velocities with the force contribution removed.
    """

    body_id = wp.tid()

    # Remove the previously applied force
    f = body_f[body_id]
    # Linear velocity change: dt * m^-1 * F
    delta_v = dt * body_inv_mass[body_id] * wp.spatial_top(f)
    r = wp.transform_get_rotation(body_q[body_id])

    # Angular velocity change: rotate the torque into the body frame, apply the
    # body-frame inverse inertia, then rotate the result back to world space
    delta_w = dt * wp.quat_rotate(r, body_inv_inertia[body_id] * wp.quat_rotate_inv(r, wp.spatial_bottom(f)))

    body_q_res[body_id] = body_q[body_id]
    body_qd_res[body_id] = body_qd[body_id] - wp.spatial_vector(delta_v, delta_w)


class Example:
    """A rigid box dropped onto a sand bed with two-way rigid/MPM coupling.

    Rigid bodies are advanced with ``SolverXPBD`` and the sand with
    ``SolverImplicitMPM``. The impulses the sand exerts on colliders are
    converted into body forces for the rigid substeps. Those same forces are
    subtracted from the body velocities handed to the MPM solver. The
    magnitude of the sand force on the dropped box is recorded every frame
    and plotted against time.
    """

    def __init__(self, viewer):
        # ========== Simulation time control ==========
        self.fps = 100
        self.frame_dt = 1.0 / self.fps
        self.sim_time = 0.0
        self.sim_substeps = 4
        self.sim_dt = self.frame_dt / self.sim_substeps

        self.viewer = viewer

        # ========== Rigid-body model ==========
        builder = newton.ModelBuilder()
        builder.default_shape_cfg.mu = 0.5
        self._emit_rigid_bodies(builder)

        # Ground plane
        builder.add_ground_plane()

        # ========== Sand model ==========
        sand_builder = newton.ModelBuilder()
        voxel_size = 0.05  # 5 cm
        self._emit_particles(sand_builder, voxel_size)

        # ========== Finalize models ==========
        self.model = builder.finalize()
        self.sand_model = sand_builder.finalize()

        # Basic particle material parameters
        self.sand_model.particle_mu = 0.48
        self.sand_model.particle_ke = 1.0e15

        # ========== MPM solver configuration ==========
        mpm_options = SolverImplicitMPM.Options()
        mpm_options.voxel_size = voxel_size
        mpm_options.tolerance = 1.0e-6
        mpm_options.grid_type = "fixed"  # fixed grid so the step can be graph-captured
        mpm_options.grid_padding = 50
        mpm_options.max_active_cell_count = 1 << 15

        mpm_options.strain_basis = "P0"
        mpm_options.max_iterations = 50
        mpm_options.critical_fraction = 0.0

        mpm_model = SolverImplicitMPM.Model(self.sand_model, mpm_options)
        # Read colliders from the rigid-body model rather than the sand model
        mpm_model.setup_collider(model=self.model)

        self.mpm_solver = SolverImplicitMPM(mpm_model, mpm_options)

        # ========== Rigid-body solver ==========
        self.solver = newton.solvers.SolverXPBD(self.model)

        # ========== Simulation state ==========
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        self.sand_state_0 = self.sand_model.state()
        self.mpm_solver.enrich_state(self.sand_state_0)

        self.control = self.model.control()
        self.contacts = self.model.collide(self.state_0)

        # ========== Visualization ==========
        self.viewer.set_model(self.model)
        if isinstance(self.viewer, newton.viewer.ViewerGL):
            self.viewer.register_ui_callback(self.render_ui, position="side")
        self.viewer.show_particles = True
        self.show_impulses = True

        # Initialize forward kinematics
        newton.eval_fk(self.model, self.model.joint_q, self.model.joint_qd, self.state_0)

        # ========== Two-way coupling buffers ==========
        # Fixed-capacity buffers holding the impulses collected from the MPM
        # solver. Unused id slots are set to -1.
        max_nodes = 1 << 20
        self.collider_impulses = wp.zeros(max_nodes, dtype=wp.vec3, device=self.model.device)
        self.collider_impulse_pos = wp.zeros(max_nodes, dtype=wp.vec3, device=self.model.device)
        self.collider_impulse_ids = wp.full(max_nodes, value=-1, dtype=int, device=self.model.device)
        self.collect_collider_impulses()

        # Mapping from collider index to rigid-body index
        self.collider_body_id = mpm_model.collider.collider_body_index

        # Per-body force and torque exerted by the sand
        self.body_sand_forces = wp.zeros_like(self.state_0.body_f)

        self.particle_render_colors = wp.full(
            self.sand_model.particle_count, value=wp.vec3(0.7, 0.6, 0.4), dtype=wp.vec3, device=self.sand_model.device
        )

        # ========== Force-vs-time plot ==========
        self.force_history = []
        self.time_history = []
        plt.ion()  # interactive mode so the figure updates without blocking
        self.fig, self.ax = plt.subplots()
        (self.line,) = self.ax.plot(self.time_history, self.force_history)
        self.ax.set_xlabel("Time (s)")
        self.ax.set_ylabel("Force (N)")
        self.ax.set_title("Interaction Force vs. Time")
        self.ax.grid(True)

        self.capture()

    def capture(self):
        """Record ``simulate()`` as a CUDA graph when that is safe.

        ``simulate()`` swaps ``state_0`` and ``state_1`` on the Python side once
        per substep. A graph replays the buffers that were bound during
        capture, so replay only matches the Python-side view of the states when
        the number of swaps per frame is even (the states end where they
        started). On CPU, or with an odd substep count, ``simulate()`` runs
        eagerly each frame instead.
        """
        if wp.get_device().is_cuda and self.sim_substeps % 2 == 0:
            with wp.ScopedCapture() as capture:
                self.simulate()
            self.graph = capture.graph
        else:
            self.graph = None

    def simulate(self):
        """Advance the coupled simulation by one frame.

        Runs ``sim_substeps`` rigid-body substeps. Each substep is driven by the
        forces derived from the most recently collected sand impulses. A single
        MPM sand step of length ``frame_dt`` follows.
        """
        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()

            # Convert the collected sand impulses into body forces. The impulses
            # come from a full MPM step, so they are divided by frame_dt.
            wp.launch(
                compute_body_forces,
                dim=self.collider_impulse_ids.shape[0],
                inputs=[
                    self.frame_dt,
                    self.collider_impulse_ids,
                    self.collider_impulses,
                    self.collider_impulse_pos,
                    self.collider_body_id,
                    self.state_0.body_q,
                    self.model.body_com,
                    self.state_0.body_f,
                ],
            )
            # Keep the sand-only forces so simulate_sand() can subtract them later
            self.body_sand_forces.assign(self.state_0.body_f)

            # Let the viewer add its own forces to the state
            self.viewer.apply_forces(self.state_0)

            # Rigid-body step
            self.contacts = self.model.collide(self.state_0)
            self.solver.step(self.state_0, self.state_1, self.control, self.contacts, self.sim_dt)

            # Swap states
            self.state_0, self.state_1 = self.state_1, self.state_0

        self.simulate_sand()

    def collect_collider_impulses(self):
        """Copy the collider impulses from the MPM solver into the fixed-size buffers.

        The id buffer is reset to -1 first, so slots beyond the collected count
        are ignored by ``compute_body_forces``. Results larger than the buffers
        are truncated.
        """
        collider_impulses, collider_impulse_pos, collider_impulse_ids = self.mpm_solver.collect_collider_impulses(
            self.sand_state_0
        )
        self.collider_impulse_ids.fill_(-1)
        n_colliders = min(collider_impulses.shape[0], self.collider_impulses.shape[0])
        self.collider_impulses[:n_colliders].assign(collider_impulses[:n_colliders])
        self.collider_impulse_pos[:n_colliders].assign(collider_impulse_pos[:n_colliders])
        self.collider_impulse_ids[:n_colliders].assign(collider_impulse_ids[:n_colliders])

    def simulate_sand(self):
        """Advance the sand by one MPM step and collect the new collider impulses."""
        # Hand the MPM solver the body velocities with the previously applied
        # sand forces removed
        if self.sand_state_0.body_q is not None:
            wp.launch(
                subtract_body_force,
                dim=self.sand_state_0.body_q.shape,
                inputs=[
                    self.frame_dt,
                    self.state_0.body_q,
                    self.state_0.body_qd,
                    self.body_sand_forces,
                    self.model.body_inv_inertia,
                    self.model.body_inv_mass,
                    self.sand_state_0.body_q,
                    self.sand_state_0.body_qd,
                ],
            )

        # MPM step (in place on the sand state)
        self.mpm_solver.step(self.sand_state_0, self.sand_state_0, contacts=None, control=None, dt=self.frame_dt)

        # Store the impulses so they can be applied back to the rigid bodies
        self.collect_collider_impulses()

    def step(self):
        """Advance one frame and record the sand force acting on the dropped box."""
        if self.graph:
            wp.capture_launch(self.graph)
        else:
            self.simulate()

        self.sim_time += self.frame_dt

        # ========== Record the total interaction force ==========
        # body_sand_forces holds (linear force, torque) per body; only the
        # linear part (first three components) is of interest here.
        force_vec = self.body_sand_forces.numpy()[self.body_block][:3]
        # Magnitude (Euclidean norm) of the full 3D force vector, not a single axis
        force_magnitude = float(np.linalg.norm(force_vec))
        self.force_history.append(force_magnitude)
        self.time_history.append(self.sim_time)

    def test(self):
        """Check that the box rests above the sand and no particle fell through the ground."""
        newton.examples.test_body_state(
            self.model,
            self.state_0,
            "all bodies are above the sand",
            lambda q, qd: q[2] > 0.45,
        )
        voxel_size = self.mpm_solver.mpm_model.voxel_size
        newton.examples.test_particle_state(
            self.sand_state_0,
            "all particles are above the ground",
            lambda q, qd: q[2] > -voxel_size,
        )

    def render(self):
        """Render the rigid bodies, sand particles and impulses, and update the force output."""
        self.viewer.begin_frame(self.sim_time)
        self.viewer.log_state(self.state_0)
        self.viewer.log_contacts(self.contacts, self.state_0)

        self.viewer.log_points(
            "/sand",
            points=self.sand_state_0.particle_q,
            radii=self.sand_model.particle_radius,
            colors=self.particle_render_colors,
            hidden=not self.viewer.show_particles,
        )

        if self.show_impulses:
            impulses, pos, _cid = self.mpm_solver.collect_collider_impulses(self.sand_state_0)
            self.viewer.log_lines(
                "/impulses",
                starts=pos,
                ends=pos + impulses,
                colors=wp.full(pos.shape[0], value=wp.vec3(1.0, 0.0, 0.0), dtype=wp.vec3),
            )
        else:
            self.viewer.log_lines("/impulses", None, None, None)

        self.viewer.end_frame()

        # ========== Terminal output and force-vs-time plot ==========
        if self.force_history:
            current_force = self.force_history[-1]
            # "\r" with end="" rewrites a single terminal line. Flush explicitly
            # because no newline is written, so line-buffered stdout would
            # otherwise hold the text back.
            print(f"\rInteraction Force: {current_force:8.2f} N", end="", flush=True)

            # Update the curve data
            self.line.set_xdata(self.time_history)
            self.line.set_ydata(self.force_history)
            self.ax.relim()
            self.ax.autoscale_view()
            self.fig.canvas.draw()
            self.fig.canvas.flush_events()

    def render_ui(self, imgui):
        """Draw the side-panel UI (impulse visibility toggle)."""
        _changed, self.show_impulses = imgui.checkbox("Show Impulses", self.show_impulses)

    def _emit_rigid_bodies(self, builder: newton.ModelBuilder):
        """Add the dropped box to ``builder`` and remember its body index in ``self.body_block``."""
        # Drop height of the shapes (z)
        drop_z = 2.0

        # A single box, given by half-extents (hx, hy, hz)
        boxes = [(0.45, 0.35, 0.25)]
        for box in boxes:
            (hx, hy, hz) = box

            pz = drop_z
            self.body_block = builder.add_body(
                xform=wp.transform(p=wp.vec3(0.0, 0.0, pz), q=wp.normalize(wp.quatf(0.0, 0.0, 0.0, 1.0))),
                mass=75.0,
            )
            builder.add_shape_box(self.body_block, hx=float(hx), hy=float(hy), hz=float(hz))

    def _emit_particles(self, sand_builder: newton.ModelBuilder, voxel_size: float):
        """Add a jittered particle grid forming the sand bed to ``sand_builder``."""
        # ------------------------------------------
        # Sand bed above the ground (2m x 2m x 0.5m)
        # ------------------------------------------

        particles_per_cell = 3.0  # particles per voxel along each axis
        density = 2500.0

        bed_lo = np.array([-1.0, -1.0, 0.0])
        bed_hi = np.array([1.0, 1.0, 0.5])
        bed_res = np.array(np.ceil(particles_per_cell * (bed_hi - bed_lo) / voxel_size), dtype=int)

        cell_size = (bed_hi - bed_lo) / bed_res
        cell_volume = np.prod(cell_size)
        radius = float(np.max(cell_size) * 0.5)
        # Per-particle mass: one particle per grid cell volume
        mass = float(cell_volume * density)

        sand_builder.add_particle_grid(
            pos=wp.vec3(bed_lo),
            rot=wp.quat_identity(),
            vel=wp.vec3(0.0),
            dim_x=bed_res[0] + 1,
            dim_y=bed_res[1] + 1,
            dim_z=bed_res[2] + 1,
            cell_x=cell_size[0],
            cell_y=cell_size[1],
            cell_z=cell_size[2],
            mass=mass,
            jitter=2.0 * radius,
            radius_mean=radius,
        )


if __name__ == "__main__":
    # 解析参数并初始化查看器
    viewer, args = newton.examples.init()

    # 创建示例并运行
    example = Example(viewer)

    newton.examples.run(example, args)
